| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325 |
- """Example training a memory neural net on the bAbI dataset.
- References Keras and is based off of https://keras.io/examples/babi_memnn/.
- """
- from __future__ import print_function
- import argparse
- import os
- import re
- import sys
- import tarfile
- import numpy as np
- from filelock import FileLock
- from ray import tune
- if sys.version_info >= (3, 12):
- # Skip this test in Python 3.12+ because TensorFlow is not supported.
- sys.exit(0)
- else:
- from tensorflow.keras.layers import (
- LSTM,
- Activation,
- Dense,
- Dropout,
- Embedding,
- Input,
- Permute,
- add,
- concatenate,
- dot,
- )
- from tensorflow.keras.models import Model, Sequential, load_model
- from tensorflow.keras.optimizers import RMSprop
- from tensorflow.keras.preprocessing.sequence import pad_sequences
- from tensorflow.keras.utils import get_file
def tokenize(sent):
    """Return the tokens of a sentence including punctuation.

    Words and runs of non-word characters become separate tokens; whitespace
    around them is stripped and whitespace-only pieces are dropped.

    >>> tokenize("Bob dropped the apple. Where is the apple?")
    ['Bob', 'dropped', 'the', 'apple', '.', 'Where', 'is', 'the', 'apple', '?']
    """
    # Split on (and keep, via the capture group) runs of non-word characters.
    # The pattern must not be able to match the empty string: since Python 3.7
    # re.split also splits on empty matches, so an optional group such as
    # "(\W+)?" would break every word into single characters.
    return [token.strip() for token in re.split(r"(\W+)", sent) if token.strip()]
def parse_stories(lines, only_supporting=False):
    """Parse stories provided in the bAbI tasks format.

    Each line is "<id> <text>". An id of 1 starts a new story. A line that
    contains a tab is a question with three tab-separated fields: the
    question, the answer, and the space-separated ids of the supporting
    lines. Every other line is a story sentence.

    Args:
        lines: Iterable of lines, either UTF-8 encoded ``bytes`` (e.g. from a
            binary file object) or ``str``. Blank lines are skipped.
        only_supporting: If true, only the sentences that support the answer
            are kept for each question.

    Returns:
        A list of ``(substory, question, answer)`` tuples, where ``substory``
        is a list of tokenized sentences, ``question`` is the tokenized
        question and ``answer`` is the raw answer string.
    """
    data = []
    story = []
    for raw_line in lines:
        line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
        line = line.strip()
        if not line:
            # Tolerate blank lines (e.g. a trailing empty line in the file)
            # instead of failing on the id/text split below.
            continue
        nid, line = line.split(" ", 1)
        nid = int(nid)
        if nid == 1:
            story = []
        if "\t" in line:
            q, a, supporting = line.split("\t")
            q = tokenize(q)
            if only_supporting:
                # Supporting ids are 1-based line ids within the story. Every
                # line (questions included) appends one entry to ``story``, so
                # id i lives at index i - 1.
                substory = [story[i - 1] for i in map(int, supporting.split())]
            else:
                # Provide all the substories, skipping question placeholders.
                substory = [x for x in story if x]
            data.append((substory, q, a))
            # Placeholder keeps ``story`` indices aligned with line ids.
            story.append("")
        else:
            story.append(tokenize(line))
    return data
def get_stories(f, only_supporting=False, max_length=None):
    """Read bAbI stories from an open file and flatten each story.

    The file's lines are parsed with ``parse_stories``, and the sentences of
    each story are concatenated into a single token list.

    Args:
        f: A file-like object whose ``readlines()`` yields lines accepted by
            ``parse_stories`` (e.g. a member returned by
            ``tarfile.TarFile.extractfile``).
        only_supporting: Forwarded to ``parse_stories``.
        max_length: If truthy, only stories with fewer than ``max_length``
            tokens are kept; stories of ``max_length`` tokens or more are
            discarded. ``None`` or ``0`` disables the filter.

    Returns:
        A list of ``(story_tokens, question_tokens, answer)`` tuples.
    """
    data = parse_stories(f.readlines(), only_supporting=only_supporting)
    stories = []
    for story, q, answer in data:
        # Flatten once, in linear time. ``sum(story, [])`` copies the
        # accumulator list on every addition and was evaluated twice per story.
        flat = [token for sentence in story for token in sentence]
        if not max_length or len(flat) < max_length:
            stories.append((flat, q, answer))
    return stories
def vectorize_stories(word_idx, story_maxlen, query_maxlen, data):
    """Encode tokenized ``(story, query, answer)`` triples as index arrays.

    Every token is looked up in ``word_idx``. Stories and queries are then
    padded (or truncated) with ``pad_sequences`` to ``story_maxlen`` and
    ``query_maxlen`` respectively.

    Returns:
        A ``(stories, queries, answers)`` tuple; ``answers`` is a numpy array
        holding the index of each answer word.
    """
    lookup = word_idx.__getitem__
    encoded_stories, encoded_queries, encoded_answers = [], [], []
    for story, query, answer in data:
        encoded_stories.append(list(map(lookup, story)))
        encoded_queries.append(list(map(lookup, query)))
        encoded_answers.append(lookup(answer))

    stories_arr = pad_sequences(encoded_stories, maxlen=story_maxlen)
    queries_arr = pad_sequences(encoded_queries, maxlen=query_maxlen)
    answers_arr = np.array(encoded_answers)
    return stories_arr, queries_arr, answers_arr
def read_data(finish_fast=False, challenge_type="single_supporting_fact_10k"):
    """Download (if needed) and load the train/test splits of a bAbI task.

    Args:
        finish_fast: If true, keep only the first 64 train and 64 test
            examples so smoke tests finish quickly.
        challenge_type: Which task to load: ``"single_supporting_fact_10k"``
            (QA1, the default) or ``"two_supporting_facts_10k"`` (QA2), both
            the 10,000-sample variants.

    Returns:
        A ``(train_stories, test_stories)`` tuple as produced by
        ``get_stories``.

    Raises:
        ValueError: If ``challenge_type`` is not a known task.
        Exception: Whatever ``get_file`` raises when the download fails;
            manual download instructions are printed to stderr first.
    """
    # Member path templates inside the archive; "{}" is "train" or "test".
    challenges = {
        # QA1 with 10,000 samples
        "single_supporting_fact_10k": "tasks_1-20_v1-2/en-10k/qa1_"
        "single-supporting-fact_{}.txt",
        # QA2 with 10,000 samples
        "two_supporting_facts_10k": "tasks_1-20_v1-2/en-10k/qa2_"
        "two-supporting-facts_{}.txt",
    }
    # Validate before downloading so a typo fails fast.
    try:
        challenge = challenges[challenge_type]
    except KeyError:
        raise ValueError(
            f"Unknown challenge_type {challenge_type!r}; "
            f"expected one of {sorted(challenges)}"
        ) from None

    # Get the file. The manual-download hint assumes the default
    # ~/.keras/datasets cache location used by get_file.
    try:
        path = get_file(
            "babi-tasks-v1-2.tar.gz",
            origin="https://s3.amazonaws.com/text-datasets/"
            "babi_tasks_1-20_v1-2.tar.gz",
        )
    except Exception:
        print(
            "Error downloading dataset, please download it manually:\n"
            "$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2"  # noqa: E501
            ".tar.gz\n"
            "$ mv tasks_1-20_v1-2.tar.gz ~/.keras/datasets/babi-tasks-v1-2.tar.gz",  # noqa: E501
            file=sys.stderr,
        )
        raise

    # Only named members are read (no extractall), and each extracted member
    # file is closed once parsed.
    with tarfile.open(path) as tar:
        with tar.extractfile(challenge.format("train")) as train_file:
            train_stories = get_stories(train_file)
        with tar.extractfile(challenge.format("test")) as test_file:
            test_stories = get_stories(test_file)

    if finish_fast:
        train_stories = train_stories[:64]
        test_stories = test_stories[:64]
    return train_stories, test_stories
class MemNNModel(tune.Trainable):
    """Ray Tune trainable for the Keras bAbI memory network.

    Hyperparameters are read from ``self.config`` with these defaults:
    ``dropout`` (0.3), ``lr`` (1e-3), ``rho`` (0.9), ``batch_size`` (32) and
    ``epochs`` (1). ``finish_fast`` (default False) is forwarded to
    ``read_data`` to shrink the dataset for smoke tests.
    """

    # Checkpoints store weights only. The architecture and optimizer are
    # always rebuilt from this trial's current config, so hyperparameters
    # mutated by PBT (dropout, lr, rho) take effect after a checkpoint is
    # restored. Restoring a full saved model with ``load_model`` would instead
    # bring back layers and an optimizer configured with the config of the
    # trial that wrote the checkpoint. Keras 3 requires the ``.weights.h5``
    # suffix for ``save_weights``; tf.keras 2.x treats any ``.h5`` path as HDF5.
    _WEIGHTS_FILENAME = "model.weights.h5"

    def build_model(self):
        """Vectorize the loaded stories and build the (uncompiled) model.

        Requires ``self.train_stories`` and ``self.test_stories``. As a side
        effect, sets ``self.inputs_train``, ``self.queries_train``,
        ``self.answers_train`` and the matching ``*_test`` attributes.

        Returns:
            An uncompiled Keras ``Model`` mapping ``[story, question]`` index
            sequences to a softmax over the vocabulary.
        """
        all_stories = self.train_stories + self.test_stories
        vocab = set()
        for story, q, answer in all_stories:
            vocab |= set(story + q + [answer])
        vocab = sorted(vocab)

        # Reserve 0 for masking via pad_sequences
        vocab_size = len(vocab) + 1
        story_maxlen = max(len(x) for x, _, _ in all_stories)
        query_maxlen = max(len(x) for _, x, _ in all_stories)

        word_idx = {c: i + 1 for i, c in enumerate(vocab)}
        self.inputs_train, self.queries_train, self.answers_train = vectorize_stories(
            word_idx, story_maxlen, query_maxlen, self.train_stories
        )
        self.inputs_test, self.queries_test, self.answers_test = vectorize_stories(
            word_idx, story_maxlen, query_maxlen, self.test_stories
        )
        dropout = self.config.get("dropout", 0.3)

        # placeholders
        input_sequence = Input((story_maxlen,))
        question = Input((query_maxlen,))

        # encoders
        # embed the input sequence into a sequence of vectors
        input_encoder_m = Sequential()
        input_encoder_m.add(Embedding(input_dim=vocab_size, output_dim=64))
        input_encoder_m.add(Dropout(dropout))
        # output: (samples, story_maxlen, embedding_dim)

        # embed the input into a sequence of vectors of size query_maxlen
        input_encoder_c = Sequential()
        input_encoder_c.add(Embedding(input_dim=vocab_size, output_dim=query_maxlen))
        input_encoder_c.add(Dropout(dropout))
        # output: (samples, story_maxlen, query_maxlen)

        # embed the question into a sequence of vectors. The sequence length
        # comes from the ``question`` Input, as for the other encoders, so no
        # ``input_length`` is passed (that argument is deprecated in Keras 3).
        question_encoder = Sequential()
        question_encoder.add(Embedding(input_dim=vocab_size, output_dim=64))
        question_encoder.add(Dropout(dropout))
        # output: (samples, query_maxlen, embedding_dim)

        # encode input sequence and questions (which are indices)
        # to sequences of dense vectors
        input_encoded_m = input_encoder_m(input_sequence)
        input_encoded_c = input_encoder_c(input_sequence)
        question_encoded = question_encoder(question)

        # compute a "match" between the first input vector sequence
        # and the question vector sequence
        # shape: `(samples, story_maxlen, query_maxlen)`
        match = dot([input_encoded_m, question_encoded], axes=(2, 2))
        match = Activation("softmax")(match)

        # add the match matrix with the second input vector sequence
        response = add([match, input_encoded_c])  # (samples, story_maxlen, query_maxlen)
        response = Permute((2, 1))(response)  # (samples, query_maxlen, story_maxlen)

        # concatenate the match matrix with the question vector sequence
        answer = concatenate([response, question_encoded])

        # the original paper uses a matrix multiplication.
        # we choose to use a RNN instead.
        answer = LSTM(32)(answer)  # (samples, 32)

        # one regularization layer -- more would probably be needed.
        answer = Dropout(dropout)(answer)
        answer = Dense(vocab_size)(answer)  # (samples, vocab_size)
        # we output a probability distribution over the vocabulary
        answer = Activation("softmax")(answer)

        # build the final model
        return Model([input_sequence, question], answer)

    def _build_and_compile(self):
        """Build the model from ``self.config`` and compile it with RMSprop."""
        model = self.build_model()
        # ``learning_rate`` is the supported keyword. The legacy ``lr`` alias
        # is deprecated, and newer Keras versions warn about it or reject it,
        # which would defeat PBT's ``lr`` mutations.
        rmsprop = RMSprop(
            learning_rate=self.config.get("lr", 1e-3),
            rho=self.config.get("rho", 0.9),
        )
        model.compile(
            optimizer=rmsprop,
            loss="sparse_categorical_crossentropy",
            metrics=["accuracy"],
        )
        return model

    def setup(self, config):
        # The lock serializes dataset download/extraction across processes
        # that share this home directory.
        with FileLock(os.path.expanduser("~/.tune.lock")):
            self.train_stories, self.test_stories = read_data(
                config.get("finish_fast", False)
            )
        self.model = self._build_and_compile()

    def reset_config(self, new_config):
        """Rebuild the model so a reused actor trains with ``new_config``.

        The dataset loaded in ``setup`` is kept. Returns True to signal that
        the reset succeeded.
        """
        self.config = new_config
        self.model = self._build_and_compile()
        return True

    def step(self):
        """Train for ``epochs`` epochs and report accuracy."""
        self.model.fit(
            [self.inputs_train, self.queries_train],
            self.answers_train,
            batch_size=self.config.get("batch_size", 32),
            epochs=self.config.get("epochs", 1),
            validation_data=([self.inputs_test, self.queries_test], self.answers_test),
            verbose=0,
        )
        # NOTE(review): accuracy is measured on the training split, and the
        # validation pass fit() runs on the test split is discarded. Consider
        # reporting held-out accuracy for PBT selection.
        _, accuracy = self.model.evaluate(
            [self.inputs_train, self.queries_train], self.answers_train, verbose=0
        )
        return {"mean_accuracy": accuracy}

    def save_checkpoint(self, checkpoint_dir):
        self.model.save_weights(os.path.join(checkpoint_dir, self._WEIGHTS_FILENAME))

    def load_checkpoint(self, checkpoint_dir):
        # Load into the model already built from the current config (see the
        # note on _WEIGHTS_FILENAME) rather than replacing it.
        self.model.load_weights(os.path.join(checkpoint_dir, self._WEIGHTS_FILENAME))
if __name__ == "__main__":
    # Script entry point: run PBT over two MemNN trials, tuning dropout, lr
    # and rho.
    import ray
    from ray.tune.schedulers import PopulationBasedTraining

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--smoke-test", action="store_true", help="Finish quickly for testing"
    )
    args, _ = parser.parse_known_args()
    smoke_test = args.smoke_test

    if smoke_test:
        ray.init(num_cpus=2)

    # PBT perturbs every ``perturbation_interval`` iterations; checkpoints are
    # written on the same interval.
    perturbation_interval = 2
    mutations = {
        "dropout": lambda: np.random.uniform(0, 1),
        "lr": lambda: 10 ** np.random.randint(-10, 0),
        "rho": lambda: np.random.uniform(0, 1),
    }
    pbt = PopulationBasedTraining(
        perturbation_interval=perturbation_interval,
        hyperparam_mutations=mutations,
    )

    max_iterations = 4 if smoke_test else 100
    run_config = tune.RunConfig(
        name="pbt_babi_memnn",
        stop={"training_iteration": max_iterations},
        checkpoint_config=tune.CheckpointConfig(
            checkpoint_frequency=perturbation_interval,
            checkpoint_score_attribute="mean_accuracy",
            num_to_keep=2,
        ),
    )
    tune_config = tune.TuneConfig(
        scheduler=pbt,
        metric="mean_accuracy",
        mode="max",
        num_samples=2,
        reuse_actors=True,
    )
    initial_params = {
        "finish_fast": smoke_test,
        "batch_size": 32,
        "epochs": 1,
        "dropout": 0.3,
        "lr": 0.01,
        "rho": 0.9,
    }

    tuner = tune.Tuner(
        MemNNModel,
        run_config=run_config,
        tune_config=tune_config,
        param_space=initial_params,
    )
    tuner.fit()
|