prodigy.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. """Prodigy integration for W&B.
  2. User can upload Prodigy annotated datasets directly
  3. from the local database to W&B in Tables format.
  4. Example usage:
  5. ```python
  6. import wandb
  7. from wandb.integration.prodigy import upload_dataset
  8. run = wandb.init(project="prodigy")
  9. upload_dataset("name_of_dataset")
  10. wandb.finish()
  11. ```
  12. """
import base64
import binascii
import collections.abc
import io
import urllib
import urllib.error
import urllib.parse
import urllib.request
from copy import deepcopy

import pandas as pd
from PIL import Image

import wandb
from wandb import util
from wandb.plot.utils import test_missing
from wandb.sdk.lib import telemetry as wb_telemetry
  24. def named_entity(docs):
  25. """Create a named entity visualization.
  26. Taken from https://github.com/wandb/wandb/blob/main/wandb/plots/named_entity.py.
  27. """
  28. spacy = util.get_module(
  29. "spacy",
  30. required="part_of_speech requires the spacy library, install with `pip install spacy`",
  31. )
  32. util.get_module(
  33. "en_core_web_md",
  34. required="part_of_speech requires `en_core_web_md` library, install with `python -m spacy download en_core_web_md`",
  35. )
  36. # Test for required packages and missing & non-integer values in docs data
  37. if test_missing(docs=docs):
  38. html = spacy.displacy.render(
  39. docs, style="ent", page=True, minify=True, jupyter=False
  40. )
  41. wandb_html = wandb.Html(html)
  42. return wandb_html
  43. def merge(dict1, dict2):
  44. """Return a new dictionary by merging two dictionaries recursively."""
  45. result = deepcopy(dict1)
  46. for key, value in dict2.items():
  47. if isinstance(value, collections.abc.Mapping):
  48. result[key] = merge(result.get(key, {}), value)
  49. else:
  50. result[key] = deepcopy(dict2[key])
  51. return result
  52. def get_schema(list_data_dict, struct, array_dict_types):
  53. """Get a schema of the dataset's structure and data types."""
  54. # Get the structure of the JSON objects in the database
  55. # This is similar to getting a JSON schema but with slightly different format
  56. for _i, item in enumerate(list_data_dict):
  57. # If the list contains dict objects
  58. for k, v in item.items():
  59. # Check if key already exists in template
  60. if k not in struct:
  61. if isinstance(v, list):
  62. if len(v) > 0 and isinstance(v[0], list):
  63. # nested list structure
  64. struct[k] = type(v) # type list
  65. elif len(v) > 0 and not (isinstance(v[0], (list, dict))):
  66. # list of singular values
  67. struct[k] = type(v) # type list
  68. else:
  69. # list of dicts
  70. array_dict_types.append(
  71. k
  72. ) # keep track of keys that are type list[dict]
  73. struct[k] = {}
  74. struct[k] = get_schema(v, struct[k], array_dict_types)
  75. elif isinstance(v, dict):
  76. struct[k] = {}
  77. struct[k] = get_schema([v], struct[k], array_dict_types)
  78. else:
  79. struct[k] = type(v)
  80. else:
  81. # Get the value of struct[k] which is the current template
  82. # Find new keys and then merge the two templates together
  83. cur_struct = struct[k]
  84. if isinstance(v, list):
  85. if len(v) > 0 and isinstance(v[0], list):
  86. # nested list coordinate structure
  87. # if the value in the item is currently None, then update
  88. if v is not None:
  89. struct[k] = type(v) # type list
  90. elif len(v) > 0 and not (isinstance(v[0], (list, dict))):
  91. # single list with values
  92. # if the value in the item is currently None, then update
  93. if v is not None:
  94. struct[k] = type(v) # type list
  95. else:
  96. array_dict_types.append(
  97. k
  98. ) # keep track of keys that are type list[dict]
  99. struct[k] = {}
  100. struct[k] = get_schema(v, struct[k], array_dict_types)
  101. # merge cur_struct and struct[k], remove duplicates
  102. struct[k] = merge(struct[k], cur_struct)
  103. elif isinstance(v, dict):
  104. struct[k] = {}
  105. struct[k] = get_schema([v], struct[k], array_dict_types)
  106. # merge cur_struct and struct[k], remove duplicates
  107. struct[k] = merge(struct[k], cur_struct)
  108. else:
  109. # if the value in the item is currently None, then update
  110. if v is not None:
  111. struct[k] = type(v)
  112. return struct
  113. def standardize(item, structure, array_dict_types):
  114. """Standardize all rows/entries in dataset to fit the schema.
  115. Will look for missing values and fill it in so all rows have
  116. the same items and structure.
  117. """
  118. for k, v in structure.items():
  119. if k not in item:
  120. # If the structure/field does not exist
  121. if isinstance(v, dict) and (k not in array_dict_types):
  122. # If key k is of type dict, and not not a type list[dict]
  123. item[k] = {}
  124. standardize(item[k], v, array_dict_types)
  125. elif isinstance(v, dict) and (k in array_dict_types):
  126. # If key k is of type dict, and is actually of type list[dict],
  127. # just treat as a list and set to None by default
  128. item[k] = None
  129. else:
  130. # Assign a default type
  131. item[k] = v()
  132. else:
  133. # If the structure/field already exists and is a list or dict
  134. if isinstance(item[k], list):
  135. # ignore if item is a nested list structure or list of non-dicts
  136. condition = (
  137. not (len(item[k]) > 0 and isinstance(item[k][0], list))
  138. ) and (
  139. not (
  140. len(item[k]) > 0 and not (isinstance(item[k][0], (list, dict)))
  141. )
  142. )
  143. if condition:
  144. for sub_item in item[k]:
  145. standardize(sub_item, v, array_dict_types)
  146. elif isinstance(item[k], dict):
  147. standardize(item[k], v, array_dict_types)
  148. def create_table(data):
  149. """Create a W&B Table.
  150. - Create/decode images from URL/Base64
  151. - Uses spacy to translate NER span data to visualizations.
  152. """
  153. # create table object from columns
  154. table_df = pd.DataFrame(data)
  155. columns = list(table_df.columns)
  156. if ("spans" in table_df.columns) and ("text" in table_df.columns):
  157. columns.append("spans_visual")
  158. if "image" in columns:
  159. columns.append("image_visual")
  160. main_table = wandb.Table(columns=columns)
  161. # Convert to dictionary format to maintain order during processing
  162. matrix = table_df.to_dict(orient="records")
  163. # Import en_core_web_md if exists
  164. en_core_web_md = util.get_module(
  165. "en_core_web_md",
  166. required="part_of_speech requires `en_core_web_md` library, install with `python -m spacy download en_core_web_md`",
  167. )
  168. nlp = en_core_web_md.load(disable=["ner"])
  169. # Go through each individual row
  170. for _i, document in enumerate(matrix):
  171. # Text NER span visualizations
  172. if ("spans_visual" in columns) and ("text" in columns):
  173. # Add visuals for spans
  174. document["spans_visual"] = None
  175. doc = nlp(document["text"])
  176. ents = []
  177. if ("spans" in document) and (document["spans"] is not None):
  178. for span in document["spans"]:
  179. if ("start" in span) and ("end" in span) and ("label" in span):
  180. charspan = doc.char_span(
  181. span["start"], span["end"], span["label"]
  182. )
  183. ents.append(charspan)
  184. doc.ents = ents
  185. document["spans_visual"] = named_entity(docs=doc)
  186. # Convert image link to wandb Image
  187. if "image" in columns:
  188. # Turn into wandb image
  189. document["image_visual"] = None
  190. if ("image" in document) and (document["image"] is not None):
  191. isurl = urllib.parse.urlparse(document["image"]).scheme in (
  192. "http",
  193. "https",
  194. )
  195. isbase64 = ("data:" in document["image"]) and (
  196. ";base64" in document["image"]
  197. )
  198. if isurl:
  199. # is url
  200. try:
  201. im = Image.open(urllib.request.urlopen(document["image"]))
  202. document["image_visual"] = wandb.Image(im)
  203. except urllib.error.URLError:
  204. wandb.termwarn(f"Image URL {document['image']} is invalid.")
  205. document["image_visual"] = None
  206. elif isbase64:
  207. # is base64 uri
  208. imgb64 = document["image"].split("base64,")[1]
  209. try:
  210. msg = base64.b64decode(imgb64)
  211. buf = io.BytesIO(msg)
  212. im = Image.open(buf)
  213. document["image_visual"] = wandb.Image(im)
  214. except base64.binascii.Error:
  215. wandb.termwarn(f"Base64 string {document['image']} is invalid.")
  216. document["image_visual"] = None
  217. else:
  218. # is data path
  219. document["image_visual"] = wandb.Image(document["image"])
  220. # Create row and append to table
  221. values_list = list(document.values())
  222. main_table.add_data(*values_list)
  223. return main_table
  224. def upload_dataset(dataset_name):
  225. """Upload dataset from local database to Weights & Biases.
  226. Args:
  227. dataset_name: The name of the dataset in the Prodigy database.
  228. """
  229. # Check if wandb.init has been called
  230. if wandb.run is None:
  231. raise ValueError("You must call wandb.init() before upload_dataset()")
  232. with wb_telemetry.context(run=wandb.run) as tel:
  233. tel.feature.prodigy = True
  234. prodigy_db = util.get_module(
  235. "prodigy.components.db",
  236. required="`prodigy` library is required but not installed. Please see https://prodi.gy/docs/install",
  237. )
  238. # Retrieve and upload prodigy dataset
  239. database = prodigy_db.connect()
  240. data = database.get_dataset(dataset_name)
  241. array_dict_types = []
  242. schema = get_schema(data, {}, array_dict_types)
  243. for i, _d in enumerate(data):
  244. standardize(data[i], schema, array_dict_types)
  245. table = create_table(data)
  246. wandb.log({dataset_name: table})
  247. wandb.termlog(f"Prodigy dataset `{dataset_name}` uploaded.")