| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284 |
- """Prodigy integration for W&B.
- User can upload Prodigy annotated datasets directly
- from the local database to W&B in Tables format.
- Example usage:
- ```python
- import wandb
- from wandb.integration.prodigy import upload_dataset
- run = wandb.init(project="prodigy")
- upload_dataset("name_of_dataset")
- wandb.finish()
- ```
- """
- import base64
- import collections.abc
- import io
- import urllib
- from copy import deepcopy
- import pandas as pd
- from PIL import Image
- import wandb
- from wandb import util
- from wandb.plot.utils import test_missing
- from wandb.sdk.lib import telemetry as wb_telemetry
- def named_entity(docs):
- """Create a named entity visualization.
- Taken from https://github.com/wandb/wandb/blob/main/wandb/plots/named_entity.py.
- """
- spacy = util.get_module(
- "spacy",
- required="part_of_speech requires the spacy library, install with `pip install spacy`",
- )
- util.get_module(
- "en_core_web_md",
- required="part_of_speech requires `en_core_web_md` library, install with `python -m spacy download en_core_web_md`",
- )
- # Test for required packages and missing & non-integer values in docs data
- if test_missing(docs=docs):
- html = spacy.displacy.render(
- docs, style="ent", page=True, minify=True, jupyter=False
- )
- wandb_html = wandb.Html(html)
- return wandb_html
- def merge(dict1, dict2):
- """Return a new dictionary by merging two dictionaries recursively."""
- result = deepcopy(dict1)
- for key, value in dict2.items():
- if isinstance(value, collections.abc.Mapping):
- result[key] = merge(result.get(key, {}), value)
- else:
- result[key] = deepcopy(dict2[key])
- return result
- def get_schema(list_data_dict, struct, array_dict_types):
- """Get a schema of the dataset's structure and data types."""
- # Get the structure of the JSON objects in the database
- # This is similar to getting a JSON schema but with slightly different format
- for _i, item in enumerate(list_data_dict):
- # If the list contains dict objects
- for k, v in item.items():
- # Check if key already exists in template
- if k not in struct:
- if isinstance(v, list):
- if len(v) > 0 and isinstance(v[0], list):
- # nested list structure
- struct[k] = type(v) # type list
- elif len(v) > 0 and not (isinstance(v[0], (list, dict))):
- # list of singular values
- struct[k] = type(v) # type list
- else:
- # list of dicts
- array_dict_types.append(
- k
- ) # keep track of keys that are type list[dict]
- struct[k] = {}
- struct[k] = get_schema(v, struct[k], array_dict_types)
- elif isinstance(v, dict):
- struct[k] = {}
- struct[k] = get_schema([v], struct[k], array_dict_types)
- else:
- struct[k] = type(v)
- else:
- # Get the value of struct[k] which is the current template
- # Find new keys and then merge the two templates together
- cur_struct = struct[k]
- if isinstance(v, list):
- if len(v) > 0 and isinstance(v[0], list):
- # nested list coordinate structure
- # if the value in the item is currently None, then update
- if v is not None:
- struct[k] = type(v) # type list
- elif len(v) > 0 and not (isinstance(v[0], (list, dict))):
- # single list with values
- # if the value in the item is currently None, then update
- if v is not None:
- struct[k] = type(v) # type list
- else:
- array_dict_types.append(
- k
- ) # keep track of keys that are type list[dict]
- struct[k] = {}
- struct[k] = get_schema(v, struct[k], array_dict_types)
- # merge cur_struct and struct[k], remove duplicates
- struct[k] = merge(struct[k], cur_struct)
- elif isinstance(v, dict):
- struct[k] = {}
- struct[k] = get_schema([v], struct[k], array_dict_types)
- # merge cur_struct and struct[k], remove duplicates
- struct[k] = merge(struct[k], cur_struct)
- else:
- # if the value in the item is currently None, then update
- if v is not None:
- struct[k] = type(v)
- return struct
- def standardize(item, structure, array_dict_types):
- """Standardize all rows/entries in dataset to fit the schema.
- Will look for missing values and fill it in so all rows have
- the same items and structure.
- """
- for k, v in structure.items():
- if k not in item:
- # If the structure/field does not exist
- if isinstance(v, dict) and (k not in array_dict_types):
- # If key k is of type dict, and not not a type list[dict]
- item[k] = {}
- standardize(item[k], v, array_dict_types)
- elif isinstance(v, dict) and (k in array_dict_types):
- # If key k is of type dict, and is actually of type list[dict],
- # just treat as a list and set to None by default
- item[k] = None
- else:
- # Assign a default type
- item[k] = v()
- else:
- # If the structure/field already exists and is a list or dict
- if isinstance(item[k], list):
- # ignore if item is a nested list structure or list of non-dicts
- condition = (
- not (len(item[k]) > 0 and isinstance(item[k][0], list))
- ) and (
- not (
- len(item[k]) > 0 and not (isinstance(item[k][0], (list, dict)))
- )
- )
- if condition:
- for sub_item in item[k]:
- standardize(sub_item, v, array_dict_types)
- elif isinstance(item[k], dict):
- standardize(item[k], v, array_dict_types)
- def create_table(data):
- """Create a W&B Table.
- - Create/decode images from URL/Base64
- - Uses spacy to translate NER span data to visualizations.
- """
- # create table object from columns
- table_df = pd.DataFrame(data)
- columns = list(table_df.columns)
- if ("spans" in table_df.columns) and ("text" in table_df.columns):
- columns.append("spans_visual")
- if "image" in columns:
- columns.append("image_visual")
- main_table = wandb.Table(columns=columns)
- # Convert to dictionary format to maintain order during processing
- matrix = table_df.to_dict(orient="records")
- # Import en_core_web_md if exists
- en_core_web_md = util.get_module(
- "en_core_web_md",
- required="part_of_speech requires `en_core_web_md` library, install with `python -m spacy download en_core_web_md`",
- )
- nlp = en_core_web_md.load(disable=["ner"])
- # Go through each individual row
- for _i, document in enumerate(matrix):
- # Text NER span visualizations
- if ("spans_visual" in columns) and ("text" in columns):
- # Add visuals for spans
- document["spans_visual"] = None
- doc = nlp(document["text"])
- ents = []
- if ("spans" in document) and (document["spans"] is not None):
- for span in document["spans"]:
- if ("start" in span) and ("end" in span) and ("label" in span):
- charspan = doc.char_span(
- span["start"], span["end"], span["label"]
- )
- ents.append(charspan)
- doc.ents = ents
- document["spans_visual"] = named_entity(docs=doc)
- # Convert image link to wandb Image
- if "image" in columns:
- # Turn into wandb image
- document["image_visual"] = None
- if ("image" in document) and (document["image"] is not None):
- isurl = urllib.parse.urlparse(document["image"]).scheme in (
- "http",
- "https",
- )
- isbase64 = ("data:" in document["image"]) and (
- ";base64" in document["image"]
- )
- if isurl:
- # is url
- try:
- im = Image.open(urllib.request.urlopen(document["image"]))
- document["image_visual"] = wandb.Image(im)
- except urllib.error.URLError:
- wandb.termwarn(f"Image URL {document['image']} is invalid.")
- document["image_visual"] = None
- elif isbase64:
- # is base64 uri
- imgb64 = document["image"].split("base64,")[1]
- try:
- msg = base64.b64decode(imgb64)
- buf = io.BytesIO(msg)
- im = Image.open(buf)
- document["image_visual"] = wandb.Image(im)
- except base64.binascii.Error:
- wandb.termwarn(f"Base64 string {document['image']} is invalid.")
- document["image_visual"] = None
- else:
- # is data path
- document["image_visual"] = wandb.Image(document["image"])
- # Create row and append to table
- values_list = list(document.values())
- main_table.add_data(*values_list)
- return main_table
- def upload_dataset(dataset_name):
- """Upload dataset from local database to Weights & Biases.
- Args:
- dataset_name: The name of the dataset in the Prodigy database.
- """
- # Check if wandb.init has been called
- if wandb.run is None:
- raise ValueError("You must call wandb.init() before upload_dataset()")
- with wb_telemetry.context(run=wandb.run) as tel:
- tel.feature.prodigy = True
- prodigy_db = util.get_module(
- "prodigy.components.db",
- required="`prodigy` library is required but not installed. Please see https://prodi.gy/docs/install",
- )
- # Retrieve and upload prodigy dataset
- database = prodigy_db.connect()
- data = database.get_dataset(dataset_name)
- array_dict_types = []
- schema = get_schema(data, {}, array_dict_types)
- for i, _d in enumerate(data):
- standardize(data[i], schema, array_dict_types)
- table = create_table(data)
- wandb.log({dataset_name: table})
- wandb.termlog(f"Prodigy dataset `{dataset_name}` uploaded.")
|