decision_boundaries.py 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. from warnings import simplefilter
  2. import wandb
  3. # ignore all future warnings
  4. simplefilter(action="ignore", category=FutureWarning)
  5. def decision_boundaries(
  6. decision_boundary_x,
  7. decision_boundary_y,
  8. decision_boundary_color,
  9. train_x,
  10. train_y,
  11. train_color,
  12. test_x,
  13. test_y,
  14. test_color,
  15. ):
  16. x_dict, y_dict, color_dict = [], [], []
  17. for i in range(min(len(decision_boundary_x), 100)):
  18. x_dict.append(decision_boundary_x[i])
  19. y_dict.append(decision_boundary_y[i])
  20. color_dict.append(decision_boundary_color)
  21. for i in range(300):
  22. x_dict.append(test_x[i])
  23. y_dict.append(test_y[i])
  24. color_dict.append(test_color[i])
  25. for i in range(min(len(train_x), 600)):
  26. x_dict.append(train_x[i])
  27. y_dict.append(train_y[i])
  28. color_dict.append(train_color[i])
  29. return wandb.visualize(
  30. "wandb/decision_boundaries/v1",
  31. wandb.Table(
  32. columns=["x", "y", "color"],
  33. data=[[x_dict[i], y_dict[i], color_dict[i]] for i in range(len(x_dict))],
  34. ),
  35. )