pbt_memnn_example.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. """Example training a memory neural net on the bAbI dataset.
  2. References Keras and is based off of https://keras.io/examples/babi_memnn/.
  3. """
  4. from __future__ import print_function
  5. import argparse
  6. import os
  7. import re
  8. import sys
  9. import tarfile
  10. import numpy as np
  11. from filelock import FileLock
  12. from ray import tune
  13. if sys.version_info >= (3, 12):
  14. # Skip this test in Python 3.12+ because TensorFlow is not supported.
  15. sys.exit(0)
  16. else:
  17. from tensorflow.keras.layers import (
  18. LSTM,
  19. Activation,
  20. Dense,
  21. Dropout,
  22. Embedding,
  23. Input,
  24. Permute,
  25. add,
  26. concatenate,
  27. dot,
  28. )
  29. from tensorflow.keras.models import Model, Sequential, load_model
  30. from tensorflow.keras.optimizers import RMSprop
  31. from tensorflow.keras.preprocessing.sequence import pad_sequences
  32. from tensorflow.keras.utils import get_file
  33. def tokenize(sent):
  34. """Return the tokens of a sentence including punctuation.
  35. >>> tokenize("Bob dropped the apple. Where is the apple?")
  36. ["Bob", "dropped", "the", "apple", ".", "Where", "is", "the", "apple", "?"]
  37. """
  38. return [x.strip() for x in re.split(r"(\W+)?", sent) if x and x.strip()]
  39. def parse_stories(lines, only_supporting=False):
  40. """Parse stories provided in the bAbi tasks format
  41. If only_supporting is true, only the sentences
  42. that support the answer are kept.
  43. """
  44. data = []
  45. story = []
  46. for line in lines:
  47. line = line.decode("utf-8").strip()
  48. nid, line = line.split(" ", 1)
  49. nid = int(nid)
  50. if nid == 1:
  51. story = []
  52. if "\t" in line:
  53. q, a, supporting = line.split("\t")
  54. q = tokenize(q)
  55. if only_supporting:
  56. # Only select the related substory
  57. supporting = map(int, supporting.split())
  58. substory = [story[i - 1] for i in supporting]
  59. else:
  60. # Provide all the substories
  61. substory = [x for x in story if x]
  62. data.append((substory, q, a))
  63. story.append("")
  64. else:
  65. sent = tokenize(line)
  66. story.append(sent)
  67. return data
  68. def get_stories(f, only_supporting=False, max_length=None):
  69. """Given a file name, read the file,
  70. retrieve the stories,
  71. and then convert the sentences into a single story.
  72. If max_length is supplied,
  73. any stories longer than max_length tokens will be discarded.
  74. """
  75. def flatten(data):
  76. return sum(data, [])
  77. data = parse_stories(f.readlines(), only_supporting=only_supporting)
  78. data = [
  79. (flatten(story), q, answer)
  80. for story, q, answer in data
  81. if not max_length or len(flatten(story)) < max_length
  82. ]
  83. return data
  84. def vectorize_stories(word_idx, story_maxlen, query_maxlen, data):
  85. inputs, queries, answers = [], [], []
  86. for story, query, answer in data:
  87. inputs.append([word_idx[w] for w in story])
  88. queries.append([word_idx[w] for w in query])
  89. answers.append(word_idx[answer])
  90. return (
  91. pad_sequences(inputs, maxlen=story_maxlen),
  92. pad_sequences(queries, maxlen=query_maxlen),
  93. np.array(answers),
  94. )
  95. def read_data(finish_fast=False):
  96. # Get the file
  97. try:
  98. path = get_file(
  99. "babi-tasks-v1-2.tar.gz",
  100. origin="https://s3.amazonaws.com/text-datasets/"
  101. "babi_tasks_1-20_v1-2.tar.gz",
  102. )
  103. except Exception:
  104. print(
  105. "Error downloading dataset, please download it manually:\n"
  106. "$ wget http://www.thespermwhale.com/jaseweston/babi/tasks_1-20_v1-2" # noqa: E501
  107. ".tar.gz\n"
  108. "$ mv tasks_1-20_v1-2.tar.gz ~/.keras/datasets/babi-tasks-v1-2.tar.gz" # noqa: E501
  109. )
  110. raise
  111. # Choose challenge
  112. challenges = {
  113. # QA1 with 10,000 samples
  114. "single_supporting_fact_10k": "tasks_1-20_v1-2/en-10k/qa1_"
  115. "single-supporting-fact_{}.txt",
  116. # QA2 with 10,000 samples
  117. "two_supporting_facts_10k": "tasks_1-20_v1-2/en-10k/qa2_"
  118. "two-supporting-facts_{}.txt",
  119. }
  120. challenge_type = "single_supporting_fact_10k"
  121. challenge = challenges[challenge_type]
  122. with tarfile.open(path) as tar:
  123. train_stories = get_stories(tar.extractfile(challenge.format("train")))
  124. test_stories = get_stories(tar.extractfile(challenge.format("test")))
  125. if finish_fast:
  126. train_stories = train_stories[:64]
  127. test_stories = test_stories[:64]
  128. return train_stories, test_stories
  129. class MemNNModel(tune.Trainable):
  130. def build_model(self):
  131. """Helper method for creating the model"""
  132. vocab = set()
  133. for story, q, answer in self.train_stories + self.test_stories:
  134. vocab |= set(story + q + [answer])
  135. vocab = sorted(vocab)
  136. # Reserve 0 for masking via pad_sequences
  137. vocab_size = len(vocab) + 1
  138. story_maxlen = max(len(x) for x, _, _ in self.train_stories + self.test_stories)
  139. query_maxlen = max(len(x) for _, x, _ in self.train_stories + self.test_stories)
  140. word_idx = {c: i + 1 for i, c in enumerate(vocab)}
  141. self.inputs_train, self.queries_train, self.answers_train = vectorize_stories(
  142. word_idx, story_maxlen, query_maxlen, self.train_stories
  143. )
  144. self.inputs_test, self.queries_test, self.answers_test = vectorize_stories(
  145. word_idx, story_maxlen, query_maxlen, self.test_stories
  146. )
  147. # placeholders
  148. input_sequence = Input((story_maxlen,))
  149. question = Input((query_maxlen,))
  150. # encoders
  151. # embed the input sequence into a sequence of vectors
  152. input_encoder_m = Sequential()
  153. input_encoder_m.add(Embedding(input_dim=vocab_size, output_dim=64))
  154. input_encoder_m.add(Dropout(self.config.get("dropout", 0.3)))
  155. # output: (samples, story_maxlen, embedding_dim)
  156. # embed the input into a sequence of vectors of size query_maxlen
  157. input_encoder_c = Sequential()
  158. input_encoder_c.add(Embedding(input_dim=vocab_size, output_dim=query_maxlen))
  159. input_encoder_c.add(Dropout(self.config.get("dropout", 0.3)))
  160. # output: (samples, story_maxlen, query_maxlen)
  161. # embed the question into a sequence of vectors
  162. question_encoder = Sequential()
  163. question_encoder.add(
  164. Embedding(input_dim=vocab_size, output_dim=64, input_length=query_maxlen)
  165. )
  166. question_encoder.add(Dropout(self.config.get("dropout", 0.3)))
  167. # output: (samples, query_maxlen, embedding_dim)
  168. # encode input sequence and questions (which are indices)
  169. # to sequences of dense vectors
  170. input_encoded_m = input_encoder_m(input_sequence)
  171. input_encoded_c = input_encoder_c(input_sequence)
  172. question_encoded = question_encoder(question)
  173. # compute a "match" between the first input vector sequence
  174. # and the question vector sequence
  175. # shape: `(samples, story_maxlen, query_maxlen)`
  176. match = dot([input_encoded_m, question_encoded], axes=(2, 2))
  177. match = Activation("softmax")(match)
  178. # add the match matrix with the second input vector sequence
  179. response = add(
  180. [match, input_encoded_c]
  181. ) # (samples, story_maxlen, query_maxlen)
  182. response = Permute((2, 1))(response) # (samples, query_maxlen, story_maxlen)
  183. # concatenate the match matrix with the question vector sequence
  184. answer = concatenate([response, question_encoded])
  185. # the original paper uses a matrix multiplication.
  186. # we choose to use a RNN instead.
  187. answer = LSTM(32)(answer) # (samples, 32)
  188. # one regularization layer -- more would probably be needed.
  189. answer = Dropout(self.config.get("dropout", 0.3))(answer)
  190. answer = Dense(vocab_size)(answer) # (samples, vocab_size)
  191. # we output a probability distribution over the vocabulary
  192. answer = Activation("softmax")(answer)
  193. # build the final model
  194. model = Model([input_sequence, question], answer)
  195. return model
  196. def setup(self, config):
  197. with FileLock(os.path.expanduser("~/.tune.lock")):
  198. self.train_stories, self.test_stories = read_data(config["finish_fast"])
  199. model = self.build_model()
  200. rmsprop = RMSprop(
  201. lr=self.config.get("lr", 1e-3), rho=self.config.get("rho", 0.9)
  202. )
  203. model.compile(
  204. optimizer=rmsprop,
  205. loss="sparse_categorical_crossentropy",
  206. metrics=["accuracy"],
  207. )
  208. self.model = model
  209. def step(self):
  210. # train
  211. self.model.fit(
  212. [self.inputs_train, self.queries_train],
  213. self.answers_train,
  214. batch_size=self.config.get("batch_size", 32),
  215. epochs=self.config.get("epochs", 1),
  216. validation_data=([self.inputs_test, self.queries_test], self.answers_test),
  217. verbose=0,
  218. )
  219. _, accuracy = self.model.evaluate(
  220. [self.inputs_train, self.queries_train], self.answers_train, verbose=0
  221. )
  222. return {"mean_accuracy": accuracy}
  223. def save_checkpoint(self, checkpoint_dir):
  224. file_path = checkpoint_dir + "/model"
  225. self.model.save(file_path)
  226. def load_checkpoint(self, checkpoint_dir):
  227. # See https://stackoverflow.com/a/42763323
  228. del self.model
  229. file_path = checkpoint_dir + "/model"
  230. self.model = load_model(file_path)
  231. if __name__ == "__main__":
  232. import ray
  233. from ray.tune.schedulers import PopulationBasedTraining
  234. parser = argparse.ArgumentParser()
  235. parser.add_argument(
  236. "--smoke-test", action="store_true", help="Finish quickly for testing"
  237. )
  238. args, _ = parser.parse_known_args()
  239. if args.smoke_test:
  240. ray.init(num_cpus=2)
  241. perturbation_interval = 2
  242. pbt = PopulationBasedTraining(
  243. perturbation_interval=perturbation_interval,
  244. hyperparam_mutations={
  245. "dropout": lambda: np.random.uniform(0, 1),
  246. "lr": lambda: 10 ** np.random.randint(-10, 0),
  247. "rho": lambda: np.random.uniform(0, 1),
  248. },
  249. )
  250. tuner = tune.Tuner(
  251. MemNNModel,
  252. run_config=tune.RunConfig(
  253. name="pbt_babi_memnn",
  254. stop={"training_iteration": 4 if args.smoke_test else 100},
  255. checkpoint_config=tune.CheckpointConfig(
  256. checkpoint_frequency=perturbation_interval,
  257. checkpoint_score_attribute="mean_accuracy",
  258. num_to_keep=2,
  259. ),
  260. ),
  261. tune_config=tune.TuneConfig(
  262. scheduler=pbt,
  263. metric="mean_accuracy",
  264. mode="max",
  265. num_samples=2,
  266. reuse_actors=True,
  267. ),
  268. param_space={
  269. "finish_fast": args.smoke_test,
  270. "batch_size": 32,
  271. "epochs": 1,
  272. "dropout": 0.3,
  273. "lr": 0.01,
  274. "rho": 0.9,
  275. },
  276. )
  277. tuner.fit()