From d832df433717b4709b16e8fdb40923cc4e9fcb2e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franti=C5=A1ek=20Kmje=C4=8D?=
Date: Fri, 17 Jun 2022 17:01:03 +0200
Subject: [PATCH] Add cartpole

---
 gym_cartpole.py       | 131 ++++++++++++++++++++++++++++++++++++++++++
 gym_cartpole_data.txt | 100 ++++++++++++++++++++++++++++++++
 2 files changed, 231 insertions(+)
 create mode 100644 gym_cartpole.py
 create mode 100644 gym_cartpole_data.txt

diff --git a/gym_cartpole.py b/gym_cartpole.py
new file mode 100644
index 0000000..ca80126
--- /dev/null
+++ b/gym_cartpole.py
@@ -0,0 +1,131 @@
+#!/usr/bin/env python3
+import argparse
+import datetime
+import os
+import re
+from typing import Optional
+os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")  # Report only TF errors by default
+
+import numpy as np
+from sklearn.model_selection import train_test_split
+import tensorflow as tf
+
+# Credits to Milan Straka for making this task;
+# github repo github.com/ufal/npfl129
+# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)
+
+def evaluate_model(
+    model: tf.keras.Model, seed: int = 42, episodes: int = 100, render: bool = False, report_per_episode: bool = False
+) -> float:
+    """Play CartPole-v1 with the model choosing the actions and return the mean episode score.
+
+    A single model output is thresholded at 0.5; two outputs are resolved by argmax.
+    """
+    import gym
+
+    env = gym.make("CartPole-v1")
+    env.seed(seed)
+
+    # Accumulate the per-episode scores; their mean is the result.
+    score_sum = 0
+    for episode_index in range(episodes):
+        state = env.reset()
+        episode_score, finished = 0, False
+        while not finished:
+            if render:
+                env.render()
+
+            outputs = model(state[np.newaxis, ...])[0].numpy()  # add a batch axis, keep the first row
+            if len(outputs) == 1:
+                action = int(outputs[0] > 0.5)
+            elif len(outputs) == 2:
+                action = np.argmax(outputs)
+            else:
+                raise ValueError("Unknown model output shape, only 1 or 2 outputs are supported")
+
+            state, reward, finished, _ = env.step(action)
+            episode_score += reward
+
+        score_sum += episode_score
+        if report_per_episode:
+            print("The episode {} finished with score {}.".format(episode_index + 1, episode_score))
+    return score_sum / episodes
+
+
+parser = argparse.ArgumentParser()
+# These arguments will be set appropriately by ReCodEx, even if you change them.
+parser.add_argument("--evaluate", default=False, action="store_true", help="Evaluate the given model")
+parser.add_argument("--recodex", default=False, action="store_true", help="Evaluation in ReCodEx.")
+parser.add_argument("--render", default=False, action="store_true", help="Render during evaluation")
+parser.add_argument("--seed", default=42, type=int, help="Random seed.")
+parser.add_argument("--threads", default=1, type=int, help="Maximum number of threads to use.")
+# If you add more arguments, ReCodEx will keep them with your default values.
+parser.add_argument("--batch_size", default=10, type=int, help="Batch size.")
+parser.add_argument("--epochs", default=100, type=int, help="Number of epochs.")
+parser.add_argument("--model", default="gym_cartpole_model.h5", type=str, help="Output model path.")
+parser.add_argument("--hidden_layer", default=200, type=int, help="Size of the hidden layer.")
+
+
+def main(args: argparse.Namespace) -> Optional[tf.keras.Model]:
+    """Train and save the CartPole classifier, or load a saved one when `args.evaluate` is set.
+
+    Returns the model when `args.recodex` is set; otherwise evaluates it, prints the average score and returns None.
+    """
+    # Fix random seeds and threads
+    np.random.seed(args.seed)
+    tf.random.set_seed(args.seed)
+    tf.config.threading.set_inter_op_parallelism_threads(args.threads)
+    tf.config.threading.set_intra_op_parallelism_threads(args.threads)
+
+    if not args.evaluate:
+        # Create logdir name
+        args.logdir = os.path.join("logs", "{}-{}-{}".format(
+            os.path.basename(globals().get("__file__", "notebook")),
+            datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S"),
+            ",".join(("{}={}".format(re.sub(r"(.)[^_]*_?", r"\1", k), v) for k, v in sorted(vars(args).items())))
+        ))
+
+        # Load the data; the last column holds the label, the preceding ones the observation
+        data = np.loadtxt("gym_cartpole_data.txt")
+        observations, labels = data[:, :-1], data[:, -1].astype(np.int32)
+        train_observations, test_observations, train_labels, test_labels = train_test_split(
+            observations, labels, test_size=0.2, random_state=args.seed, stratify=labels)
+
+        # Binary classifier with one sigmoid output; the input width is taken from the data columns.
+        num_features = observations.shape[1]
+        model = tf.keras.Sequential()
+        model.add(tf.keras.layers.Input([num_features]))
+        model.add(tf.keras.layers.Dense(args.hidden_layer, activation=tf.nn.relu))
+        model.add(tf.keras.layers.Dense(1, activation=tf.nn.sigmoid))
+
+        # Prepare the model for training using the `model.compile` method.
+        model.compile(
+            optimizer=tf.optimizers.Adam(),
+            loss=tf.losses.BinaryCrossentropy(),
+            metrics=[tf.metrics.BinaryAccuracy("accuracy")]
+        )
+
+        tb_callback = tf.keras.callbacks.TensorBoard(args.logdir, histogram_freq=1)
+        model.fit(
+            train_observations, train_labels,
+            validation_data=(test_observations, test_labels),
+            batch_size=args.batch_size, epochs=args.epochs,
+            callbacks=[tb_callback]
+        )
+
+        # Save the model, without the optimizer state.
+        model.save(args.model, include_optimizer=False)
+
+    else:
+        # Evaluating, either manually or in ReCodEx
+        model = tf.keras.models.load_model(args.model, compile=False)
+
+    if args.recodex:
+        return model
+    score = evaluate_model(model, seed=args.seed, render=args.render, report_per_episode=True)
+    print("The average score was {}.".format(score))
+
+
+if __name__ == "__main__":
+    args = parser.parse_args([] if "__file__" not in globals() else None)
+    main(args)
diff --git a/gym_cartpole_data.txt b/gym_cartpole_data.txt
new file mode 100644
index 0000000..e08b3d7
--- /dev/null
+++ b/gym_cartpole_data.txt
@@ -0,0 +1,100 @@
+-0.310 0.322 -0.022 0.042 0
+-0.505 -0.150 -0.022 -0.338 1
+-0.160 -0.738 0.080 0.254 1
+-0.471 -0.052 -0.000 0.140 0
+-0.918 -0.703 0.058 0.073 0
+1.336 -0.546 -0.063 0.391 1
+-1.175 0.138 0.162 -0.084 0
+-1.465 -0.183 0.033 0.006 0
+-0.282 -0.401 -0.010 0.323 0
+1.729 -0.590 0.017 0.325 1
+0.498 -0.004 -0.034 -0.081 0
+-0.830 -0.515 0.080 0.109 1
+0.000 0.043 0.021 -0.082 1
+0.209 -0.693 -0.038 0.221 0
+0.153 0.210 -0.025 -0.021 1
+0.278 0.786 -0.041 -0.400 0
+-0.169 -0.502 0.043 0.029 1
+0.751 0.010 -0.109 0.052 0
+0.053 0.372 -0.041 -0.145 0
+-0.367 0.606 0.026 -0.597 0
+0.317 -0.574 -0.098 0.292 0
+0.656 0.679 -0.026 -0.067 0
+-0.047 0.591 -0.122 -0.339 0
+-0.069 0.521 -0.021 -0.007 1
+-0.485 -0.369 -0.085 0.215 1
+-0.216 -0.918 -0.007 0.509 1
+-0.149 0.350 0.016 -0.158 0
+-0.072 -0.219 -0.066 -0.182 1
+-0.512 -0.326 -0.041 -0.002 1
+0.050 0.194 0.071 0.037 1
+0.025 0.179 -0.133 0.087 0
+-0.290 1.088 -0.146 -0.438 0 +-0.308 -0.713 -0.020 0.329 1 +-0.313 -0.347 0.004 0.067 0 +0.138 -0.063 -0.073 0.352 1 +-0.172 0.895 0.002 -0.488 0 +-0.068 -0.787 -0.052 0.282 0 +-0.296 -0.025 -0.015 -0.263 0 +-0.767 -0.355 -0.047 0.292 1 +0.422 0.720 -0.051 -0.218 1 +-1.656 0.217 0.038 -0.340 0 +-0.305 -0.010 -0.062 -0.104 0 +-0.972 0.426 0.069 -0.558 0 +-0.932 0.146 -0.043 0.305 1 +0.428 -0.738 -0.007 0.329 1 +0.727 -0.382 -0.026 0.370 0 +0.634 0.020 0.061 0.185 0 +0.532 0.573 -0.009 -0.593 0 +-0.511 0.188 -0.038 -0.050 1 +0.497 -0.555 -0.072 0.174 1 +0.035 -0.032 0.023 0.197 1 +-0.559 -0.136 0.033 -0.007 1 +0.369 0.543 -0.065 -0.003 1 +0.826 -0.175 0.018 -0.069 0 +-0.734 -0.697 0.076 -0.004 0 +-0.494 0.165 0.010 0.038 0 +0.407 -0.744 -0.099 0.403 1 +-0.010 0.367 0.021 -0.130 0 +0.675 -0.235 -0.073 0.486 1 +-0.385 -0.176 0.101 0.242 1 +-0.556 0.157 0.002 0.020 1 +0.217 -0.353 -0.056 0.366 1 +-0.042 -0.540 -0.033 0.420 0 +-0.014 -0.045 -0.008 -0.031 1 +0.357 -0.059 -0.025 0.312 1 +-0.154 0.207 0.023 -0.038 0 +-1.624 -0.570 0.006 0.049 1 +0.067 -0.154 -0.002 0.494 1 +-0.106 -0.427 -0.096 0.339 1 +-0.237 -0.313 -0.020 0.020 0 +-0.322 1.502 -0.048 -0.669 0 +-0.478 0.195 0.004 0.013 1 +-0.937 0.039 -0.039 -0.429 0 +-0.078 -0.205 -0.003 0.037 1 +-1.097 -0.204 0.045 -0.011 1 +0.080 0.201 0.120 -0.108 1 +0.004 -0.362 0.017 0.270 1 +0.878 0.305 -0.019 0.197 1 +1.108 -0.058 -0.033 0.175 1 +-0.637 -0.041 -0.090 -0.119 0 +-0.111 -0.322 -0.018 0.075 1 +-0.860 -0.325 0.045 -0.083 1 +-0.005 0.002 0.004 -0.010 1 +-0.947 -0.137 0.027 -0.207 0 +1.165 0.191 -0.010 -0.172 0 +-0.356 0.933 -0.008 -0.444 1 +-0.404 -0.507 0.035 0.276 1 +0.005 -0.376 -0.085 0.217 1 +-0.394 -0.341 -0.101 -0.207 0 +-0.293 0.766 0.063 -0.249 1 +-0.547 0.159 -0.013 -0.299 0 +-0.054 0.030 0.020 0.328 1 +-0.862 0.942 0.064 -0.360 1 +-0.147 -0.979 -0.026 0.514 1 +0.269 0.157 -0.004 0.037 0 +-0.042 -0.229 -0.007 0.081 0 +1.088 1.416 -0.001 -0.269 0 +-0.211 0.599 0.045 -0.316 0 +0.131 -0.550 -0.015 0.509 1 +-1.478 
0.809 0.087 -0.895 0