迭代器

迭代器(Iterator)的定义是,实现了 __iter__()__next__() 方法的对象,其中__iter__()方法返回迭代器自身(即 self),__next__()方法返回下一个元素,若没有元素则抛出 StopIteration 异常。

听上去优点抽象,我们还是来看一个例子,根据上面的定义,只要写一个类,并实现 __iter__()__next__() 两个魔法方法,就是一个迭代器了,例如

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class MyIterator:
def __init__(self, limit):
self.limit = limit
self.counter = 0

def __iter__(self):
return self

def __next__(self):
if self.counter < self.limit:
result = self.counter
self.counter += 1
return result
else:
raise StopIteration

函数的功能很简单,初始值为0,返回的下一个元素是上一个元素加1,直至返回值等于设定的self.limit,则不返回任何元素,并按照迭代器的定义,抛出 StopIteration 异常。让我们来看看这个迭代器用起来怎么样

1
2
3
my_iter = MyIterator(5)
for i in range(5):
print(f"Iterator output: {next(my_iter)}")

实例化上述实现的类,得到一个迭代器对象,使用next()就可以调用__next__()方法,返回下一个元素。上述代码输出为

1
2
3
4
5
Iterator output: 0
Iterator output: 1
Iterator output: 2
Iterator output: 3
Iterator output: 4

而比较常用的写法是配合for循环

1
2
3
my_iter = MyIterator(5)
for i in my_iter:
print(i) # 输出: 0, 1, 2, 3, 4

这种写法会自动完整迭代一遍迭代器,并在完成迭代后自动停止。

生成器

生成器(Generator)是一种简洁的迭代器实现方式,使用 yield 关键字替代函数中的return,实现暂停函数并保留状态的效果。例如上面的迭代器,使用生成器的写法就是

1
2
3
4
5
def my_generator(limit):
counter = 0
while counter < limit:
yield counter
counter += 1

yield 关键字可以理解为一个可以进行多次返回的return,每运行一次 yield 就进行一次返回,我们配合下面的代码来理解

1
2
3
my_gen = my_generator(5)
for i in range(5):
print(f"Generator output: {next(my_gen)}")

上述代码中,不难看出,实例化的对象一共会运行5次yield关键字,也就是一共会进行5次返回,而与迭代器的使用方法一样,使用next()就可以让生成器进行一次返回,因此上述代码输出为

1
2
3
4
5
Generator output: 0
Generator output: 1
Generator output: 2
Generator output: 3
Generator output: 4

同理,生成器也可以配合for循环,实现自动迭代

1
2
3
my_gen = my_generator(5)
for i in my_gen:
print(i) # 输出: 0, 1, 2, 3, 4

最后,还是要再次强调,生成器本质上就是迭代器,只是其实现更加简洁,不需要写成类的形式,用函数的形式就可以实现。

迭代器与生成器的使用场景

讲完了使用方法,最后我们来说说为什么要使用迭代器与生成器。我们不如再来看一段代码,是一个简单的for循环

1
2
for i in range(5):
print(i)

这段代码实现的功能与上述示例代码完全相同,那为什么还要大费周章使用复杂的迭代器与生成器呢?答案是节省内存

我们来对比一下两种实现方式,同样是输出5个数字,使用for循环时,代码预先生成了range(5),即包含了5个数字的列表,然后再从中一个一个的读取数字,而使用迭代器时,在初始化迭代器时,我们并没有生成全部的5个数字,而是在每一次需要输出时,才生成本次输出需要的数字。也就是说,使用迭代器与生成器时,不需要一次性将所有的数据载入内存,而是仅在需要使用时载入,可以达到节省内存的目的

综上,当使用for循环会导致爆内存时,就要使用迭代器与生成器(能用简单代码解决的,干嘛用这么复杂的语法,例如,深度学习中的小批量数据读取

1
2
3
4
5
6
7
8
9
10
11
12
# 小批量数据读取
def data_iter(batch_size, features, labels):
# batch_size:批量大小,features:特征,labels:标签

data_size = np.size(features, 0) # 数据集数量
index = np.random.permutation(data_size) # 生成随机索引

for i in range(0, data_size, batch_size):
batch_index = index[i: min(i + batch_size, data_size)] # batch的索引列表
X, y = features[batch_index], labels[batch_index]
X, y = X.reshape(np.size(batch_index), -1), y.reshape(np.size(batch_index), -1) # 固定形状
yield X, y

深度学习中使用的数据集一般都比较大,如果一次性将全部数据载入显存中进行训练,那么很容易炸显存,这种场景就需要使用迭代器与生成器,每一次训练的时候就将一个batch_size的数据载入显存,训练完成后清除,再载入下一个batch_size,就可以规避掉爆显存的问题了。

注:当然了,这篇文章是笔者在手写梯度下降的过程中,遇到了爆内存的问题,最后用生成器解决了,故写一篇简单的博客记录一下,迭代器与生成器的使用场景当然不止这一个啦。