fine_tuning.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
  1. from __future__ import annotations
  2. import base64
  3. import datetime
  4. import io
  5. import json
  6. import os
  7. import re
  8. import tempfile
  9. import time
  10. from typing import Any
  11. from packaging.version import parse
  12. import wandb
  13. from wandb import util
  14. from wandb.data_types import Table
  15. from wandb.sdk.lib import telemetry
  16. openai = util.get_module(
  17. name="openai",
  18. required="This integration requires `openai`. To install, please run `pip install openai`",
  19. lazy=False,
  20. )
  21. if parse(openai.__version__) < parse("1.12.0"):
  22. raise wandb.Error(
  23. f"This integration requires openai version 1.12.0 and above. Your current version is {openai.__version__} "
  24. "To fix, please `pip install -U openai`"
  25. )
  26. from openai import OpenAI # noqa: E402
  27. from openai.types.fine_tuning import FineTuningJob # noqa: E402
  28. from openai.types.fine_tuning.fine_tuning_job import ( # noqa: E402
  29. Error,
  30. Hyperparameters,
  31. )
  32. np = util.get_module(
  33. name="numpy",
  34. required="`numpy` not installed >> This integration requires numpy! To fix, please `pip install numpy`",
  35. lazy=False,
  36. )
  37. pd = util.get_module(
  38. name="pandas",
  39. required="`pandas` not installed >> This integration requires pandas! To fix, please `pip install pandas`",
  40. lazy=False,
  41. )
  42. class WandbLogger:
  43. """Log OpenAI fine-tunes to [Weights & Biases](https://wandb.me/openai-docs)."""
  44. _wandb_api: wandb.Api | None = None
  45. _logged_in: bool = False
  46. openai_client: OpenAI | None = None
  47. _run: wandb.Run | None = None
  48. @classmethod
  49. def sync(
  50. cls,
  51. fine_tune_job_id: str | None = None,
  52. openai_client: OpenAI | None = None,
  53. num_fine_tunes: int | None = None,
  54. project: str = "OpenAI-Fine-Tune",
  55. entity: str | None = None,
  56. overwrite: bool = False,
  57. wait_for_job_success: bool = True,
  58. log_datasets: bool = True,
  59. model_artifact_name: str = "model-metadata",
  60. model_artifact_type: str = "model",
  61. **kwargs_wandb_init: dict[str, Any],
  62. ) -> str:
  63. """Sync fine-tunes to Weights & Biases.
  64. :param fine_tune_job_id: The id of the fine-tune (optional)
  65. :param openai_client: Pass the `OpenAI()` client (optional)
  66. :param num_fine_tunes: Number of most recent fine-tunes to log when an fine_tune_job_id is not provided. By default, every fine-tune is synced.
  67. :param project: Name of the project where you're sending runs. By default, it is "GPT-3".
  68. :param entity: Username or team name where you're sending runs. By default, your default entity is used, which is usually your username.
  69. :param overwrite: Forces logging and overwrite existing wandb run of the same fine-tune.
  70. :param wait_for_job_success: Waits for the fine-tune to be complete and then log metrics to W&B. By default, it is True.
  71. :param model_artifact_name: Name of the model artifact that is logged
  72. :param model_artifact_type: Type of the model artifact that is logged
  73. """
  74. if openai_client is None:
  75. openai_client = OpenAI()
  76. cls.openai_client = openai_client
  77. if fine_tune_job_id:
  78. wandb.termlog("Retrieving fine-tune job...")
  79. fine_tune = openai_client.fine_tuning.jobs.retrieve(
  80. fine_tuning_job_id=fine_tune_job_id
  81. )
  82. fine_tunes = [fine_tune]
  83. else:
  84. # get list of fine_tune to log
  85. fine_tunes = openai_client.fine_tuning.jobs.list()
  86. if not fine_tunes or fine_tunes.data is None:
  87. wandb.termwarn("No fine-tune has been retrieved")
  88. return
  89. # Select the `num_fine_tunes` from the `fine_tunes.data` list.
  90. # If `num_fine_tunes` is None, it selects all items in the list (from start to end).
  91. # If for example, `num_fine_tunes` is 5, it selects the last 5 items in the list.
  92. # Note that the last items in the list are the latest fine-tune jobs.
  93. fine_tunes = fine_tunes.data[
  94. -num_fine_tunes if num_fine_tunes is not None else None :
  95. ]
  96. # log starting from oldest fine_tune
  97. show_individual_warnings = (
  98. fine_tune_job_id is not None or num_fine_tunes is not None
  99. )
  100. fine_tune_logged = []
  101. for fine_tune in fine_tunes:
  102. fine_tune_id = fine_tune.id
  103. # check run with the given `fine_tune_id` has not been logged already
  104. run_path = f"{project}/{fine_tune_id}"
  105. if entity is not None:
  106. run_path = f"{entity}/{run_path}"
  107. wandb_run = cls._get_wandb_run(run_path)
  108. if wandb_run:
  109. wandb_status = wandb_run.summary.get("status")
  110. if show_individual_warnings:
  111. if wandb_status == "succeeded" and not overwrite:
  112. wandb.termwarn(
  113. f"Fine-tune {fine_tune_id} has already been logged successfully at {wandb_run.url}. "
  114. "Use `overwrite=True` if you want to overwrite previous run"
  115. )
  116. elif wandb_status != "succeeded" or overwrite:
  117. if wandb_status != "succeeded":
  118. wandb.termwarn(
  119. f"A run for fine-tune {fine_tune_id} was previously created but didn't end successfully"
  120. )
  121. wandb.termlog(
  122. f"A new wandb run will be created for fine-tune {fine_tune_id} and previous run will be overwritten"
  123. )
  124. overwrite = True
  125. if wandb_status == "succeeded" and not overwrite:
  126. return
  127. # check if the user has not created a wandb run externally
  128. if wandb.run is None:
  129. cls._run = wandb.init(
  130. job_type="fine-tune",
  131. project=project,
  132. entity=entity,
  133. name=fine_tune_id,
  134. id=fine_tune_id,
  135. **kwargs_wandb_init,
  136. )
  137. else:
  138. # if a run exits - created externally
  139. cls._run = wandb.run
  140. if wait_for_job_success:
  141. fine_tune = cls._wait_for_job_success(fine_tune)
  142. cls._log_fine_tune(
  143. fine_tune,
  144. project,
  145. entity,
  146. overwrite,
  147. show_individual_warnings,
  148. log_datasets,
  149. model_artifact_name,
  150. model_artifact_type,
  151. **kwargs_wandb_init,
  152. )
  153. if not show_individual_warnings and not any(fine_tune_logged):
  154. wandb.termwarn("No new successful fine-tunes were found")
  155. return "🎉 wandb sync completed successfully"
  156. @classmethod
  157. def _wait_for_job_success(cls, fine_tune: FineTuningJob) -> FineTuningJob:
  158. wandb.termlog("Waiting for the OpenAI fine-tuning job to finish training...")
  159. wandb.termlog(
  160. "To avoid blocking, you can call `WandbLogger.sync` with `wait_for_job_success=False` after OpenAI training completes."
  161. )
  162. while True:
  163. if fine_tune.status == "succeeded":
  164. wandb.termlog(
  165. "Fine-tuning finished, logging metrics, model metadata, and run metadata to Weights & Biases"
  166. )
  167. return fine_tune
  168. if fine_tune.status == "failed":
  169. wandb.termwarn(
  170. f"Fine-tune {fine_tune.id} has failed and will not be logged"
  171. )
  172. return fine_tune
  173. if fine_tune.status == "cancelled":
  174. wandb.termwarn(
  175. f"Fine-tune {fine_tune.id} was cancelled and will not be logged"
  176. )
  177. return fine_tune
  178. time.sleep(10)
  179. fine_tune = cls.openai_client.fine_tuning.jobs.retrieve(
  180. fine_tuning_job_id=fine_tune.id
  181. )
  182. @classmethod
  183. def _log_fine_tune(
  184. cls,
  185. fine_tune: FineTuningJob,
  186. project: str,
  187. entity: str | None,
  188. overwrite: bool,
  189. show_individual_warnings: bool,
  190. log_datasets: bool,
  191. model_artifact_name: str,
  192. model_artifact_type: str,
  193. **kwargs_wandb_init: dict[str, Any],
  194. ):
  195. fine_tune_id = fine_tune.id
  196. status = fine_tune.status
  197. with telemetry.context(run=cls._run) as tel:
  198. tel.feature.openai_finetuning = True
  199. # check run completed successfully
  200. if status != "succeeded":
  201. if show_individual_warnings:
  202. wandb.termwarn(
  203. f'Fine-tune {fine_tune_id} has the status "{status}" and will not be logged'
  204. )
  205. return
  206. # check results are present
  207. try:
  208. results_id = fine_tune.result_files[0]
  209. try:
  210. encoded_results = cls.openai_client.files.content(
  211. file_id=results_id
  212. ).read()
  213. results = base64.b64decode(encoded_results).decode("utf-8")
  214. except Exception:
  215. # attempt to read as text, works for older jobs
  216. results = cls.openai_client.files.content(file_id=results_id).text
  217. except openai.NotFoundError:
  218. if show_individual_warnings:
  219. wandb.termwarn(
  220. f"Fine-tune {fine_tune_id} has no results and will not be logged"
  221. )
  222. return
  223. # update the config
  224. cls._run.config.update(cls._get_config(fine_tune))
  225. # log results
  226. df_results = pd.read_csv(io.StringIO(results))
  227. for _, row in df_results.iterrows():
  228. metrics = {k: v for k, v in row.items() if not np.isnan(v)}
  229. step = metrics.pop("step")
  230. if step is not None:
  231. step = int(step)
  232. cls._run.log(metrics, step=step)
  233. fine_tuned_model = fine_tune.fine_tuned_model
  234. if fine_tuned_model is not None:
  235. cls._run.summary["fine_tuned_model"] = fine_tuned_model
  236. # training/validation files and fine-tune details
  237. cls._log_artifacts(
  238. fine_tune,
  239. project,
  240. entity,
  241. log_datasets,
  242. overwrite,
  243. model_artifact_name,
  244. model_artifact_type,
  245. )
  246. # mark run as complete
  247. cls._run.summary["status"] = "succeeded"
  248. cls._run.finish()
  249. return True
  250. @classmethod
  251. def _ensure_logged_in(cls):
  252. if not cls._logged_in:
  253. if wandb.login():
  254. cls._logged_in = True
  255. else:
  256. raise Exception(
  257. "It appears you are not currently logged in to Weights & Biases. "
  258. "Please run `wandb login` in your terminal or `wandb.login()` in a notebook. "
  259. "Create a new API key at https://wandb.ai/settings and store it securely."
  260. )
  261. @classmethod
  262. def _get_wandb_run(cls, run_path: str):
  263. cls._ensure_logged_in()
  264. try:
  265. if cls._wandb_api is None:
  266. cls._wandb_api = wandb.Api()
  267. return cls._wandb_api.run(run_path)
  268. except Exception:
  269. return None
  270. @classmethod
  271. def _get_wandb_artifact(cls, artifact_path: str):
  272. cls._ensure_logged_in()
  273. try:
  274. if cls._wandb_api is None:
  275. cls._wandb_api = wandb.Api()
  276. return cls._wandb_api.artifact(artifact_path)
  277. except Exception:
  278. return None
  279. @classmethod
  280. def _get_config(cls, fine_tune: FineTuningJob) -> dict[str, Any]:
  281. config = dict(fine_tune)
  282. config["result_files"] = config["result_files"][0]
  283. if config.get("created_at"):
  284. config["created_at"] = datetime.datetime.fromtimestamp(
  285. config["created_at"]
  286. ).strftime("%Y-%m-%d %H:%M:%S")
  287. if config.get("finished_at"):
  288. config["finished_at"] = datetime.datetime.fromtimestamp(
  289. config["finished_at"]
  290. ).strftime("%Y-%m-%d %H:%M:%S")
  291. if config.get("hyperparameters"):
  292. config["hyperparameters"] = cls.sanitize(config["hyperparameters"])
  293. if config.get("error"):
  294. config["error"] = cls.sanitize(config["error"])
  295. return config
  296. @classmethod
  297. def _unpack_hyperparameters(cls, hyperparameters: Hyperparameters):
  298. # `Hyperparameters` object is not unpacking properly using `vars` or `__dict__`,
  299. # vars(hyperparameters) return {n_epochs: n} only.
  300. hyperparams = {}
  301. try:
  302. hyperparams["n_epochs"] = hyperparameters.n_epochs
  303. hyperparams["batch_size"] = hyperparameters.batch_size
  304. hyperparams["learning_rate_multiplier"] = (
  305. hyperparameters.learning_rate_multiplier
  306. )
  307. except Exception:
  308. # If unpacking fails, return the object to be logged as config
  309. return None
  310. return hyperparams
  311. @staticmethod
  312. def sanitize(input: Any) -> dict | list | str:
  313. valid_types = [bool, int, float, str]
  314. if isinstance(input, (Hyperparameters, Error)):
  315. return dict(input)
  316. if isinstance(input, dict):
  317. return {
  318. k: v if type(v) in valid_types else str(v) for k, v in input.items()
  319. }
  320. elif isinstance(input, list):
  321. return [v if type(v) in valid_types else str(v) for v in input]
  322. else:
  323. return str(input)
  324. @classmethod
  325. def _log_artifacts(
  326. cls,
  327. fine_tune: FineTuningJob,
  328. project: str,
  329. entity: str | None,
  330. log_datasets: bool,
  331. overwrite: bool,
  332. model_artifact_name: str,
  333. model_artifact_type: str,
  334. ) -> None:
  335. if log_datasets:
  336. wandb.termlog("Logging training/validation files...")
  337. # training/validation files
  338. training_file = fine_tune.training_file if fine_tune.training_file else None
  339. validation_file = (
  340. fine_tune.validation_file if fine_tune.validation_file else None
  341. )
  342. for file, prefix, artifact_type in (
  343. (training_file, "train", "training_files"),
  344. (validation_file, "valid", "validation_files"),
  345. ):
  346. if file is not None:
  347. cls._log_artifact_inputs(
  348. file, prefix, artifact_type, project, entity, overwrite
  349. )
  350. # fine-tune details
  351. fine_tune_id = fine_tune.id
  352. artifact = wandb.Artifact(
  353. model_artifact_name,
  354. type=model_artifact_type,
  355. metadata=dict(fine_tune),
  356. )
  357. with artifact.new_file("model_metadata.json", mode="w", encoding="utf-8") as f:
  358. dict_fine_tune = dict(fine_tune)
  359. dict_fine_tune["hyperparameters"] = cls.sanitize(
  360. dict_fine_tune["hyperparameters"]
  361. )
  362. dict_fine_tune["error"] = cls.sanitize(dict_fine_tune["error"])
  363. dict_fine_tune = cls.sanitize(dict_fine_tune)
  364. json.dump(dict_fine_tune, f, indent=2)
  365. cls._run.log_artifact(
  366. artifact,
  367. aliases=["latest", fine_tune_id],
  368. )
  369. @classmethod
  370. def _log_artifact_inputs(
  371. cls,
  372. file_id: str | None,
  373. prefix: str,
  374. artifact_type: str,
  375. project: str,
  376. entity: str | None,
  377. overwrite: bool,
  378. ) -> None:
  379. # get input artifact
  380. artifact_name = f"{prefix}-{file_id}"
  381. # sanitize name to valid wandb artifact name
  382. artifact_name = re.sub(r"[^a-zA-Z0-9_\-.]", "_", artifact_name)
  383. artifact_alias = file_id
  384. artifact_path = f"{project}/{artifact_name}:{artifact_alias}"
  385. if entity is not None:
  386. artifact_path = f"{entity}/{artifact_path}"
  387. artifact = cls._get_wandb_artifact(artifact_path)
  388. # create artifact if file not already logged previously
  389. if artifact is None or overwrite:
  390. # get file content
  391. try:
  392. file_content = cls.openai_client.files.content(file_id=file_id)
  393. except openai.NotFoundError:
  394. wandb.termerror(
  395. f"File {file_id} could not be retrieved. Make sure you have OpenAI permissions to download training/validation files"
  396. )
  397. return
  398. artifact = wandb.Artifact(artifact_name, type=artifact_type)
  399. with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
  400. tmp_file.write(file_content.content)
  401. tmp_file_path = tmp_file.name
  402. artifact.add_file(tmp_file_path, file_id)
  403. os.unlink(tmp_file_path)
  404. # create a Table
  405. try:
  406. table, n_items = cls._make_table(file_content.text)
  407. # Add table to the artifact.
  408. artifact.add(table, file_id)
  409. # Add the same table to the workspace.
  410. cls._run.log({f"{prefix}_data": table})
  411. # Update the run config and artifact metadata
  412. cls._run.config.update({f"n_{prefix}": n_items})
  413. artifact.metadata["items"] = n_items
  414. except Exception as e:
  415. wandb.termerror(
  416. f"Issue saving {file_id} as a Table to Artifacts, exception:\n '{e}'"
  417. )
  418. else:
  419. # log number of items
  420. cls._run.config.update({f"n_{prefix}": artifact.metadata.get("items")})
  421. cls._run.use_artifact(artifact, aliases=["latest", artifact_alias])
  422. @classmethod
  423. def _make_table(cls, file_content: str) -> tuple[Table, int]:
  424. table = wandb.Table(columns=["role: system", "role: user", "role: assistant"])
  425. df = pd.read_json(io.StringIO(file_content), orient="records", lines=True)
  426. for _idx, message in df.iterrows():
  427. messages = message.messages
  428. assert len(messages) == 3
  429. table.add_data(
  430. messages[0]["content"],
  431. messages[1]["content"],
  432. messages[2]["content"],
  433. )
  434. return table, len(df)