utils.py 747 B

123456789101112131415161718192021
  1. import tensorflow as tf
  2. from sklearn.datasets import load_iris
  3. from sklearn.model_selection import train_test_split
  4. from sklearn.preprocessing import OneHotEncoder
  5. def get_iris_data(test_size=0.2):
  6. iris_data = load_iris()
  7. x = iris_data.data
  8. y = iris_data.target.reshape(-1, 1)
  9. encoder = OneHotEncoder(sparse=False)
  10. y = encoder.fit_transform(y)
  11. train_x, test_x, train_y, test_y = train_test_split(x, y)
  12. return train_x, train_y, test_x, test_y
  13. def set_keras_threads(threads):
  14. # We set threads here to avoid contention, as Keras
  15. # is heavily parallelized across multiple cores.
  16. tf.config.threading.set_inter_op_parallelism_threads(threads)
  17. tf.config.threading.set_intra_op_parallelism_threads(threads)