proto_util.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. #
  2. from __future__ import annotations
  3. import json
  4. from typing import TYPE_CHECKING, Any
  5. from wandb.proto import wandb_internal_pb2 as pb
  6. if TYPE_CHECKING: # pragma: no cover
  7. from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
  8. from google.protobuf.message import Message
  9. from wandb.proto import wandb_telemetry_pb2 as tpb
  10. def dict_from_proto_list(obj_list: RepeatedCompositeFieldContainer) -> dict[str, Any]:
  11. result: dict[str, Any] = {}
  12. for item in obj_list:
  13. # Start from the root of the result dict
  14. current_level = result
  15. if len(item.nested_key) > 0:
  16. keys = list(item.nested_key)
  17. else:
  18. keys = [item.key]
  19. for key in keys[:-1]:
  20. if key not in current_level:
  21. current_level[key] = {}
  22. # Move the reference deeper into the nested dictionary
  23. current_level = current_level[key]
  24. # Set the value at the final key location, parsing JSON from the value_json field
  25. final_key = keys[-1]
  26. current_level[final_key] = json.loads(item.value_json)
  27. return result
  28. def _result_from_record(record: pb.Record) -> pb.Result:
  29. result = pb.Result(uuid=record.uuid, control=record.control)
  30. return result
  31. def _assign_record_num(record: pb.Record, record_num: int) -> None:
  32. record.num = record_num
  33. def _assign_end_offset(record: pb.Record, end_offset: int) -> None:
  34. record.control.end_offset = end_offset
  35. def proto_encode_to_dict(
  36. pb_obj: tpb.TelemetryRecord | pb.MetricRecord,
  37. ) -> dict[int, Any]:
  38. data: dict[int, Any] = dict()
  39. fields = pb_obj.ListFields()
  40. for desc, value in fields:
  41. if desc.name.startswith("_"):
  42. continue
  43. if desc.type == desc.TYPE_STRING:
  44. data[desc.number] = value
  45. elif desc.type == desc.TYPE_INT32:
  46. data[desc.number] = value
  47. elif desc.type == desc.TYPE_ENUM:
  48. data[desc.number] = value
  49. elif desc.type == desc.TYPE_MESSAGE:
  50. nested = value.ListFields()
  51. bool_msg = all(d.type == d.TYPE_BOOL for d, _ in nested)
  52. if bool_msg:
  53. items = [d.number for d, v in nested if v]
  54. if items:
  55. data[desc.number] = items
  56. else:
  57. # TODO: for now this code only handles sub-messages with strings
  58. md = {}
  59. for d, v in nested:
  60. if not v or d.type != d.TYPE_STRING:
  61. continue
  62. md[d.number] = v
  63. data[desc.number] = md
  64. return data
  65. def message_to_dict(
  66. message: Message,
  67. ) -> dict[str, Any]:
  68. """Convert a protobuf message into a dictionary."""
  69. from google.protobuf.json_format import MessageToDict
  70. return MessageToDict(message, preserving_proto_field_name=True)