Batch
In this tutorial, we introduce Batch, the most basic data structure in Tianshou. You can think of Batch as a NumPy version of a Python dictionary.
import numpy as np
from tianshou.data import Batch
import torch
import pickle
data = Batch(a=4, b=[5, 5], c="2312312", d=("a", -2, -3))
print(data)
print(data.b)
Batch(
a: array(4),
b: array([5, 5]),
c: '2312312',
d: array(['a', '-2', '-3'], dtype=object),
)
[5 5]
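A Batch stores all passed-in data as key-value pairs, much like a dictionary, and automatically converts each value into a NumPy array where possible. In the output above, the scalar a becomes a 0-dimensional array, the list b and the tuple d become arrays (d, which mixes strings and integers, becomes an object array), and the string c is kept unchanged.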
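To make "where possible" concrete, here is a minimal sketch of a conversion rule that reproduces the output above. The helper `convert_value` is hypothetical and uses plain NumPy only. It imitates the printed results and is not Tianshou's actual implementation.

```python
import numpy as np


def convert_value(value):
    """Hypothetical helper imitating the conversion printed above (not Tianshou's code)."""
    if isinstance(value, str):
        return value  # strings are stored unchanged
    arr = np.asanyarray(value)  # scalars -> 0-d arrays, lists/tuples -> n-d arrays
    if not issubclass(arr.dtype.type, (np.bool_, np.number)):
        # non-numeric or mixed data (e.g. strings mixed with ints) falls back to object dtype
        arr = arr.astype(object)
    return arr


raw = {"a": 4, "b": [5, 5], "c": "2312312", "d": ("a", -2, -3)}
converted = {key: convert_value(value) for key, value in raw.items()}
for key, value in converted.items():
    print(key, repr(value))
```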
Why do we need Batch in Tianshou?
The motivation behind the Batch module is simple. In deep reinforcement learning (DRL), you need to handle a lot of dictionary-format data. For instance, most algorithms require you to store the state, action, and reward of every step when interacting with the environment. All of this data can be organized as a dictionary, and Batch helps Tianshou unify the interface across a diverse set of algorithms. In addition, Batch supports advanced indexing, concatenation, splitting, and formatted printing, just like a NumPy array, which can be very helpful for developers.
Basic Usage
Initialization
A Batch can be created directly from a Python dictionary, and all values will be converted to NumPy arrays where possible.
# converted from a python dictionary
print("========================================")
batch1 = Batch({"a": [4, 4], "b": (5, 5)})
print(batch1)
# initialization of batch2 is equivalent to batch1
print("========================================")
batch2 = Batch(a=[4, 4], b=(5, 5))
print(batch2)
# the dictionary can be nested, and it will be turned into a nested Batch
print("========================================")
data = {
"action": np.array([1.0, 2.0, 3.0]),
"reward": 3.66,
"obs": {
"rgb_obs": np.zeros((3, 3)),
"flatten_obs": np.ones(5),
},
}
batch3 = Batch(data, extra="extra_string")
print(batch3)
# batch3.obs is also a Batch
print(type(batch3.obs))
print(batch3.obs.rgb_obs)
# a list of dicts/Batches is automatically stacked into one Batch, which is convenient
# when you use parallelized environments to collect data
print("========================================")
batch4 = Batch([data] * 3)
print(batch4)
print(batch4.obs.rgb_obs.shape)
========================================
Batch(
a: array([4, 4]),
b: array([5, 5]),
)
========================================
Batch(
a: array([4, 4]),
b: array([5, 5]),
)
========================================
Batch(
action: array([1., 2., 3.]),
reward: array(3.66),
obs: Batch(
rgb_obs: array([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]),
flatten_obs: array([1., 1., 1., 1., 1.]),
),
extra: 'extra_string',
)
<class 'tianshou.data.batch.Batch'>
[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]
========================================
Batch(
action: array([[1., 2., 3.],
[1., 2., 3.],
[1., 2., 3.]]),
reward: array([3.66, 3.66, 3.66]),
obs: Batch(
flatten_obs: array([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]]),
rgb_obs: array([[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]]),
),
)
(3, 3, 3)
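Conceptually, building a Batch from a list of dictionaries stacks each leaf value along a new leading axis and recurses into nested dictionaries. The sketch below reproduces the shapes printed above with plain NumPy. `stack_nested` is a hypothetical helper that assumes every dictionary has the same keys. It returns plain dicts rather than Batch objects.

```python
import numpy as np


def stack_nested(dicts):
    """Hypothetical helper: stack a list of (possibly nested) dicts leaf by leaf
    along a new leading axis, imitating Batch([d1, d2, ...])."""
    out = {}
    for key, value in dicts[0].items():
        if isinstance(value, dict):
            # recurse so that nested dicts become nested stacked dicts
            out[key] = stack_nested([d[key] for d in dicts])
        else:
            # scalars become 0-d arrays first, so stacking them yields a 1-d array
            out[key] = np.stack([np.asarray(d[key]) for d in dicts])
    return out


data = {
    "action": np.array([1.0, 2.0, 3.0]),
    "reward": 3.66,
    "obs": {"rgb_obs": np.zeros((3, 3)), "flatten_obs": np.ones(5)},
}
stacked = stack_nested([data] * 3)
print(stacked["action"].shape)          # (3, 3)
print(stacked["reward"])                # [3.66 3.66 3.66]
print(stacked["obs"]["rgb_obs"].shape)  # (3, 3, 3)
```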
Getting access to data
You can conveniently look up or modify key-value pairs in a Batch, just as if it were a Python dictionary. Keys can be accessed both as attributes (batch1.c) and as items (batch1["c"]).
batch1 = Batch({"a": [4, 4], "b": (5, 5)})
print(batch1)
# add or delete key-value pairs in batch1
print("========================================")
batch1.c = Batch(c1=np.arange(3), c2=False)
del batch1.a
print(batch1)
# access value by key
print("========================================")
assert batch1["c"] is batch1.c
print("c" in batch1)
# traverse the Batch
print("========================================")
for key, value in batch1.items():
    print(str(key) + ": " + str(value))
Batch(
a: array([4, 4]),
b: array([5, 5]),
)
========================================
Batch(
b: array([5, 5]),
c: Batch(
c1: array([0, 1, 2]),
c2: array(False),
),
)
========================================
True
========================================
b: [5 5]
c: Batch(
c1: array([0, 1, 2]),
c2: array(False),
)
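A tiny dict subclass can illustrate this combined attribute and item access. `AttrDict` below is a hypothetical toy. It is not how Batch is implemented and performs no NumPy conversion. It does follow the same add, delete, lookup, and traversal pattern as the example above.

```python
class AttrDict(dict):
    """Hypothetical toy: a dict whose keys are also reachable as attributes."""

    def __getattr__(self, key):
        # only called when normal attribute lookup fails, so dict methods still work
        try:
            return self[key]
        except KeyError as err:
            raise AttributeError(key) from err

    def __setattr__(self, key, value):
        self[key] = value  # d.x = v behaves like d["x"] = v

    def __delattr__(self, key):
        try:
            del self[key]  # del d.x behaves like del d["x"]
        except KeyError as err:
            raise AttributeError(key) from err


d = AttrDict(a=[4, 4], b=(5, 5))
d.c = AttrDict(c1=list(range(3)), c2=False)  # add a key via attribute syntax
del d.a                                       # delete a key via attribute syntax
print("c" in d, "a" in d)  # True False
for key, value in d.items():
    print(key, value)
```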
Indexing and Slicing
If all values in a Batch share the same size along their leading dimension(s), the Batch supports advanced indexing and slicing just like a normal NumPy array.
# suppose we have 4 environments, each returning one step of data
step_datas = [
{
"act": np.random.randint(10),
"rew": 0.0,
"obs": np.ones((3, 3)),
"info": {"done": np.random.choice(2), "failed": False},
}
for _ in range(4)
]
batch = Batch(step_datas)
print(batch)
print(batch.shape)
# advanced indexing lets us select data from a given subset of environments
print("========================================")
print(batch[0])
print(batch[[0, 3]])
# slicing is also supported
print("========================================")
print(batch[-2:])
Batch(
rew: array([0., 0., 0., 0.]),
act: array([3, 5, 1, 3]),
obs: array([[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]]),
info: Batch(
done: array([1, 0, 0, 0]),
failed: array([False, False, False, False]),
),
)
[4]
========================================
Batch(
rew: 0.0,
act: 3,
obs: array([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]),
info: Batch(
done: 1,
failed: False,
),
)
Batch(
rew: array([0., 0.]),
act: array([3, 3]),
obs: array([[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]]),
info: Batch(
done: array([1, 0]),
failed: array([False, False]),
),
)
========================================
Batch(
rew: array([0., 0.]),
act: array([1, 3]),
obs: array([[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]]]),
info: Batch(
done: array([0, 0]),
failed: array([False, False]),
),
)
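Indexing a Batch behaves as if the same index were applied to every leaf value, including the values inside nested Batches. Here is a minimal NumPy sketch of that idea, using a hypothetical `index_nested` helper on plain dicts. The random actions come from NumPy's `default_rng` rather than from the example above.

```python
import numpy as np


def index_nested(tree, index):
    """Hypothetical helper: apply the same numpy index to every leaf array,
    recursing into nested dicts."""
    return {
        key: index_nested(value, index) if isinstance(value, dict) else np.asarray(value)[index]
        for key, value in tree.items()
    }


rng = np.random.default_rng(0)
tree = {
    "act": rng.integers(10, size=4),
    "rew": np.zeros(4),
    "obs": np.ones((4, 3, 3)),
    "info": {"done": np.array([1, 0, 0, 0]), "failed": np.zeros(4, dtype=bool)},
}
sub = index_nested(tree, [0, 3])                 # like batch[[0, 3]]
print(sub["obs"].shape)                          # (2, 3, 3)
print(sub["info"]["done"])                       # [1 0]
last_two = index_nested(tree, slice(-2, None))   # like batch[-2:]
print(last_two["info"]["done"])                  # [0 0]
```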
Aggregation and Splitting
Again, this works just like with NumPy arrays. Try the example code below.
# concat batches with compatible keys
# try incompatible keys yourself if you feel curious
print("========================================")
b1 = Batch(a=[{"b": np.float64(1.0), "d": Batch(e=np.array(3.0))}])
b2 = Batch(a=[{"b": np.float64(4.0), "d": {"e": np.array(6.0)}}])
b12_cat_out = Batch.cat([b1, b2])
print(b1)
print(b2)
print(b12_cat_out)
# stack batches with compatible keys
# try incompatible keys yourself if you feel curious
print("========================================")
b3 = Batch(a=np.zeros((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[1], [2]]))
b4 = Batch(a=np.ones((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[0], [3]]))
b34_stack = Batch.stack((b3, b4), axis=1)
print(b3)
print(b4)
print(b34_stack)
# split the batch into mini-batches of size 1; shuffle=True randomizes their order
print("========================================")
print(type(b34_stack.split(1)))
print(list(b34_stack.split(1, shuffle=True)))
========================================
Batch(
a: Batch(
b: array([1.]),
d: Batch(
e: array([3.]),
),
),
)
Batch(
a: Batch(
b: array([4.]),
d: Batch(
e: array([6.]),
),
),
)
Batch(
a: Batch(
b: array([1., 4.]),
d: Batch(
e: array([3., 6.]),
),
),
)
========================================
Batch(
a: array([[0., 0.],
[0., 0.],
[0., 0.]]),
b: array([[1., 1., 1.],
[1., 1., 1.]]),
c: Batch(
d: array([[1],
[2]]),
),
)
Batch(
a: array([[1., 1.],
[1., 1.],
[1., 1.]]),
b: array([[1., 1., 1.],
[1., 1., 1.]]),
c: Batch(
d: array([[0],
[3]]),
),
)
Batch(
a: array([[[0., 0.],
[1., 1.]],
[[0., 0.],
[1., 1.]],
[[0., 0.],
[1., 1.]]]),
b: array([[[1., 1., 1.],
[1., 1., 1.]],
[[1., 1., 1.],
[1., 1., 1.]]]),
c: Batch(
d: array([[[1],
[0]],
[[2],
[3]]]),
),
)
========================================
<class 'generator'>
[Batch(
a: array([[[0., 0.],
[1., 1.]]]),
b: array([[[1., 1., 1.],
[1., 1., 1.]]]),
c: Batch(
d: array([[[2],
[3]]]),
),
), Batch(
a: array([[[0., 0.],
[1., 1.]]]),
b: array([[[1., 1., 1.],
[1., 1., 1.]]]),
c: Batch(
d: array([[[1],
[0]]]),
),
)]
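The shapes above follow NumPy's rules. In these examples, `Batch.cat` joins along the existing first axis like `np.concatenate`, while `Batch.stack` inserts a new axis like `np.stack`. The sketch below checks those shapes on arrays shaped like b3.a and b4.a. It also adds a hypothetical `split_indices` generator that imitates splitting into chunks of a given size, optionally shuffled, by way of index arrays. It illustrates the idea and is not Tianshou's code.

```python
import numpy as np

# cat joins along an existing axis; stack inserts a new one
a3, a4 = np.zeros((3, 2)), np.ones((3, 2))
print(np.concatenate([a3, a4]).shape)    # (6, 2)
print(np.stack([a3, a4], axis=1).shape)  # (3, 2, 2)


def split_indices(length, size, shuffle=False, seed=None):
    """Hypothetical helper: yield index arrays of at most `size` elements,
    in random order when shuffle=True."""
    if shuffle:
        order = np.random.default_rng(seed).permutation(length)
    else:
        order = np.arange(length)
    for start in range(0, length, size):
        yield order[start:start + size]  # the last chunk may be smaller


chunks = list(split_indices(5, 2, shuffle=True, seed=0))
print([len(c) for c in chunks])  # [2, 2, 1]
```

Each index array could then be used to index every value, as in the indexing sketch of the previous section.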
Data type conversion
Besides NumPy arrays, Batch also supports PyTorch tensors, and the usage is exactly the same. Cool, isn’t it?
batch1 = Batch(a=np.arange(2), b=torch.zeros((2, 2)))
batch2 = Batch(a=np.arange(2), b=torch.ones((2, 2)))
batch_cat = Batch.cat([batch1, batch2, batch1])
print(batch_cat)
Batch(
a: array([0, 1, 0, 1, 0, 1]),
b: tensor([[0., 0.],
[0., 0.],
[1., 1.],
[1., 1.],
[0., 0.],
[0., 0.]]),
)
If you no longer want to mix data types, you can easily convert all values to a single type.
batch_cat.to_numpy()
print(batch_cat)
batch_cat.to_torch()
print(batch_cat)
Batch(
a: array([0, 1, 0, 1, 0, 1]),
b: array([[0., 0.],
[0., 0.],
[1., 1.],
[1., 1.],
[0., 0.],
[0., 0.]], dtype=float32),
)
Batch(
a: tensor([0, 1, 0, 1, 0, 1]),
b: tensor([[0., 0.],
[0., 0.],
[1., 1.],
[1., 1.],
[0., 0.],
[0., 0.]]),
)
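One way to picture this conversion is as a recursive walk over the values that swaps one array type for the other. The sketch below uses plain nested dicts and the hypothetical helpers `to_numpy_nested` and `to_torch_nested`. It relies only on standard PyTorch calls (`torch.is_tensor`, `Tensor.numpy`, `torch.from_numpy`) and is not Tianshou's implementation.

```python
import numpy as np
import torch


def to_numpy_nested(tree):
    """Hypothetical helper: recursively replace torch tensors with numpy arrays."""
    out = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            out[key] = to_numpy_nested(value)
        elif torch.is_tensor(value):
            out[key] = value.detach().cpu().numpy()  # detach/cpu make .numpy() safe
        else:
            out[key] = value
    return out


def to_torch_nested(tree):
    """Hypothetical helper: recursively replace numpy arrays with torch tensors."""
    out = {}
    for key, value in tree.items():
        if isinstance(value, dict):
            out[key] = to_torch_nested(value)
        elif isinstance(value, np.ndarray):
            out[key] = torch.from_numpy(value)  # shares memory with the array
        else:
            out[key] = value
    return out


mixed = {"a": np.arange(2), "b": torch.zeros((2, 2)), "obs": {"c": torch.ones(3)}}
as_numpy = to_numpy_nested(mixed)
print(type(as_numpy["b"]).__name__, as_numpy["b"].dtype)  # ndarray float32
as_torch = to_torch_nested(as_numpy)
print(type(as_torch["obs"]["c"]).__name__)  # Tensor
```

Because `torch.from_numpy` shares memory with the source array, an implementation that must not alias its inputs may prefer to copy.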
Batch is also serializable with pickle, in case you need to save it to disk and restore it later.
batch = Batch(obs=Batch(a=0.0, c=torch.Tensor([1.0, 2.0])), np=np.zeros([3, 4]))
batch_pk = pickle.loads(pickle.dumps(batch))
print(batch_pk)
Batch(
obs: Batch(
a: array(0.),
c: tensor([1., 2.]),
),
np: array([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[0., 0., 0., 0.]]),
)
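To actually write a Batch to a file, use the standard `pickle.dump` and `pickle.load` calls. Below is a minimal sketch with the hypothetical helpers `save_pickle` and `load_pickle`. To keep it self-contained, it pickles a nested dict of NumPy arrays. Batch is picklable, as shown above, so the same helpers should work for a Batch object.

```python
import os
import pickle
import tempfile

import numpy as np


def save_pickle(obj, path):
    """Hypothetical helper: serialize any picklable object to `path`."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def load_pickle(path):
    """Hypothetical helper: restore an object previously written by save_pickle."""
    with open(path, "rb") as f:
        return pickle.load(f)


payload = {"obs": {"a": np.array(0.0)}, "np": np.zeros((3, 4))}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "batch.pkl")
    save_pickle(payload, path)
    restored = load_pickle(path)
print(restored["np"].shape)  # (3, 4)
```

Only load pickle files from trusted sources, because unpickling can execute arbitrary code.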
Further Reading
Would you like to learn more advanced usages of Batch? Curious about how data is organized inside a Batch? Check the documentation and other tutorials for more details.