Add cartpole

This commit is contained in:
František Kmječ 2022-06-17 17:01:03 +02:00
parent 3d580c8ef2
commit d832df4337
2 changed files with 231 additions and 0 deletions

131
gym_cartpole.py Normal file
View File

@ -0,0 +1,131 @@
#!/usr/bin/env python3
import argparse
import datetime
import os
import re
from typing import Optional
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2") # Report only TF errors by default
import numpy as np
from sklearn.model_selection import train_test_split
import tensorflow as tf
# Credits to Milan Straka for making this task;
# github repo github.com/ufal/npfl129
# Attribution-ShareAlike 4.0 International (CC BY-SA 4.0)
def evaluate_model(
    model: tf.keras.Model, seed: int = 42, episodes: int = 100, render: bool = False, report_per_episode: bool = False
) -> float:
    """Evaluate the given model on the CartPole-v1 environment.

    The model is called on one observation at a time (a leading batch axis is
    added) and its output is converted to a discrete action:

    - one output: action 1 if the value is above 0.5, otherwise action 0;
    - two outputs: the index of the larger value (argmax).

    Args:
        model: Model producing 1 or 2 outputs per observation.
        seed: Seed passed to the environment via `env.seed`.
        episodes: Number of episodes to run.
        render: Whether to render the environment before every step.
        report_per_episode: Whether to print the score after each episode.

    Returns:
        The average score achieved over the given number of episodes.

    Raises:
        ValueError: If the model produces neither 1 nor 2 outputs.
    """
    # Imported lazily so that training does not require gym to be installed.
    import gym

    # Create the environment
    env = gym.make("CartPole-v1")
    try:
        env.seed(seed)

        # Evaluate the episodes
        total_score = 0
        for episode in range(episodes):
            observation, score, done = env.reset(), 0, False
            while not done:
                if render:
                    env.render()

                prediction = model(observation[np.newaxis, ...])[0].numpy()
                if len(prediction) == 1:
                    action = 1 if prediction[0] > 0.5 else 0
                elif len(prediction) == 2:
                    action = np.argmax(prediction)
                else:
                    raise ValueError("Unknown model output shape, only 1 or 2 outputs are supported")

                observation, reward, done, info = env.step(action)
                score += reward

            total_score += score
            if report_per_episode:
                print(f"The episode {episode + 1} finished with score {score}.")
    finally:
        # Release the environment (and any render window) even when the model
        # or the environment raises in the middle of an episode.
        env.close()

    return total_score / episodes
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser with the ReCodEx and training options."""
    cli = argparse.ArgumentParser()

    # These arguments will be set appropriately by ReCodEx, even if you change them.
    for flag, description in (
        ("--evaluate", "Evaluate the given model"),
        ("--recodex", "Evaluation in ReCodEx."),
        ("--render", "Render during evaluation"),
    ):
        cli.add_argument(flag, default=False, action="store_true", help=description)
    cli.add_argument("--seed", default=42, type=int, help="Random seed.")
    cli.add_argument("--threads", default=1, type=int, help="Maximum number of threads to use.")

    # If you add more arguments, ReCodEx will keep them with your default values.
    for flag, default, description in (
        ("--batch_size", 10, "Batch size."),
        ("--epochs", 100, "Number of epochs."),
    ):
        cli.add_argument(flag, default=default, type=int, help=description)
    cli.add_argument("--model", default="gym_cartpole_model.h5", type=str, help="Output model path.")
    cli.add_argument("--hidden_layer", default=200, type=int, help="Size of the hidden layer.")
    return cli


parser = _build_parser()
def main(args: argparse.Namespace) -> Optional[tf.keras.Model]:
    """Train and save a CartPole classifier, or load and evaluate a saved one.

    Without `args.evaluate`, the data in `gym_cartpole_data.txt` is split
    80/20 (stratified by label), a one-hidden-layer binary classifier is
    trained on it with TensorBoard logging to `args.logdir`, and the model is
    saved to `args.model` without optimizer state.

    With `args.evaluate`, the model is loaded from `args.model`; it is returned
    directly when `args.recodex` is set, otherwise it is evaluated with
    `evaluate_model` and the average score is printed.

    Args:
        args: Parsed command-line arguments; `args.logdir` is set as a side
            effect on the training path.

    Returns:
        The loaded model when evaluating in ReCodEx, otherwise None.
    """
    # Fix random seeds and threads
    np.random.seed(args.seed)
    tf.random.set_seed(args.seed)
    tf.config.threading.set_inter_op_parallelism_threads(args.threads)
    tf.config.threading.set_intra_op_parallelism_threads(args.threads)

    if args.evaluate:
        # Evaluating, either manually or in ReCodEx
        model = tf.keras.models.load_model(args.model, compile=False)
        if args.recodex:
            return model

        score = evaluate_model(model, seed=args.seed, render=args.render, report_per_episode=True)
        print("The average score was {}.".format(score))
        return None

    # Create logdir name: script name, timestamp and abbreviated arguments
    # (each argument name is reduced to the initials of its `_`-separated parts).
    script_name = os.path.basename(globals().get("__file__", "notebook"))
    timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")
    abbreviated_args = ",".join(
        "{}={}".format(re.sub("(.)[^_]*_?", r"\1", key), value) for key, value in sorted(vars(args).items())
    )
    args.logdir = os.path.join("logs", "{}-{}-{}".format(script_name, timestamp, abbreviated_args))

    # Load the data; the last column is the label, the rest are observations.
    dataset = np.loadtxt("gym_cartpole_data.txt")
    observations = dataset[:, :-1]
    labels = dataset[:, -1].astype(np.int32)
    split = train_test_split(observations, labels, test_size=0.2, random_state=args.seed, stratify=labels)
    train_observations, test_observations, train_labels, test_labels = split

    # Binary classifier: one ReLU hidden layer followed by a single sigmoid output.
    num_features = 4
    model = tf.keras.Sequential([
        tf.keras.layers.Input([num_features]),
        tf.keras.layers.Dense(args.hidden_layer, activation=tf.nn.relu),
        tf.keras.layers.Dense(1, activation=tf.nn.sigmoid),
    ])

    # Prepare the model for training.
    model.compile(
        optimizer=tf.optimizers.Adam(),
        loss=tf.losses.BinaryCrossentropy(),
        metrics=[tf.metrics.BinaryAccuracy("accuracy")],
    )

    tensorboard = tf.keras.callbacks.TensorBoard(args.logdir, histogram_freq=1)
    model.fit(
        train_observations,
        train_labels,
        validation_data=(test_observations, test_labels),
        batch_size=args.batch_size,
        epochs=args.epochs,
        callbacks=[tensorboard],
    )

    # Save the model, without the optimizer state.
    model.save(args.model, include_optimizer=False)
    return None
if __name__ == "__main__":
    # When `__file__` is defined, read the real command line (passing None
    # makes argparse use sys.argv); otherwise parse an empty argument list so
    # that only the defaults apply.
    args = parser.parse_args(None if "__file__" in globals() else [])
    main(args)

100
gym_cartpole_data.txt Normal file
View File

@ -0,0 +1,100 @@
-0.310 0.322 -0.022 0.042 0
-0.505 -0.150 -0.022 -0.338 1
-0.160 -0.738 0.080 0.254 1
-0.471 -0.052 -0.000 0.140 0
-0.918 -0.703 0.058 0.073 0
1.336 -0.546 -0.063 0.391 1
-1.175 0.138 0.162 -0.084 0
-1.465 -0.183 0.033 0.006 0
-0.282 -0.401 -0.010 0.323 0
1.729 -0.590 0.017 0.325 1
0.498 -0.004 -0.034 -0.081 0
-0.830 -0.515 0.080 0.109 1
0.000 0.043 0.021 -0.082 1
0.209 -0.693 -0.038 0.221 0
0.153 0.210 -0.025 -0.021 1
0.278 0.786 -0.041 -0.400 0
-0.169 -0.502 0.043 0.029 1
0.751 0.010 -0.109 0.052 0
0.053 0.372 -0.041 -0.145 0
-0.367 0.606 0.026 -0.597 0
0.317 -0.574 -0.098 0.292 0
0.656 0.679 -0.026 -0.067 0
-0.047 0.591 -0.122 -0.339 0
-0.069 0.521 -0.021 -0.007 1
-0.485 -0.369 -0.085 0.215 1
-0.216 -0.918 -0.007 0.509 1
-0.149 0.350 0.016 -0.158 0
-0.072 -0.219 -0.066 -0.182 1
-0.512 -0.326 -0.041 -0.002 1
0.050 0.194 0.071 0.037 1
0.025 0.179 -0.133 0.087 0
-0.290 1.088 -0.146 -0.438 0
-0.308 -0.713 -0.020 0.329 1
-0.313 -0.347 0.004 0.067 0
0.138 -0.063 -0.073 0.352 1
-0.172 0.895 0.002 -0.488 0
-0.068 -0.787 -0.052 0.282 0
-0.296 -0.025 -0.015 -0.263 0
-0.767 -0.355 -0.047 0.292 1
0.422 0.720 -0.051 -0.218 1
-1.656 0.217 0.038 -0.340 0
-0.305 -0.010 -0.062 -0.104 0
-0.972 0.426 0.069 -0.558 0
-0.932 0.146 -0.043 0.305 1
0.428 -0.738 -0.007 0.329 1
0.727 -0.382 -0.026 0.370 0
0.634 0.020 0.061 0.185 0
0.532 0.573 -0.009 -0.593 0
-0.511 0.188 -0.038 -0.050 1
0.497 -0.555 -0.072 0.174 1
0.035 -0.032 0.023 0.197 1
-0.559 -0.136 0.033 -0.007 1
0.369 0.543 -0.065 -0.003 1
0.826 -0.175 0.018 -0.069 0
-0.734 -0.697 0.076 -0.004 0
-0.494 0.165 0.010 0.038 0
0.407 -0.744 -0.099 0.403 1
-0.010 0.367 0.021 -0.130 0
0.675 -0.235 -0.073 0.486 1
-0.385 -0.176 0.101 0.242 1
-0.556 0.157 0.002 0.020 1
0.217 -0.353 -0.056 0.366 1
-0.042 -0.540 -0.033 0.420 0
-0.014 -0.045 -0.008 -0.031 1
0.357 -0.059 -0.025 0.312 1
-0.154 0.207 0.023 -0.038 0
-1.624 -0.570 0.006 0.049 1
0.067 -0.154 -0.002 0.494 1
-0.106 -0.427 -0.096 0.339 1
-0.237 -0.313 -0.020 0.020 0
-0.322 1.502 -0.048 -0.669 0
-0.478 0.195 0.004 0.013 1
-0.937 0.039 -0.039 -0.429 0
-0.078 -0.205 -0.003 0.037 1
-1.097 -0.204 0.045 -0.011 1
0.080 0.201 0.120 -0.108 1
0.004 -0.362 0.017 0.270 1
0.878 0.305 -0.019 0.197 1
1.108 -0.058 -0.033 0.175 1
-0.637 -0.041 -0.090 -0.119 0
-0.111 -0.322 -0.018 0.075 1
-0.860 -0.325 0.045 -0.083 1
-0.005 0.002 0.004 -0.010 1
-0.947 -0.137 0.027 -0.207 0
1.165 0.191 -0.010 -0.172 0
-0.356 0.933 -0.008 -0.444 1
-0.404 -0.507 0.035 0.276 1
0.005 -0.376 -0.085 0.217 1
-0.394 -0.341 -0.101 -0.207 0
-0.293 0.766 0.063 -0.249 1
-0.547 0.159 -0.013 -0.299 0
-0.054 0.030 0.020 0.328 1
-0.862 0.942 0.064 -0.360 1
-0.147 -0.979 -0.026 0.514 1
0.269 0.157 -0.004 0.037 0
-0.042 -0.229 -0.007 0.081 0
1.088 1.416 -0.001 -0.269 0
-0.211 0.599 0.045 -0.316 0
0.131 -0.550 -0.015 0.509 1
-1.478 0.809 0.087 -0.895 0