ニューラルネットワーク自作入門メモ

ニューラルネットワーク自作入門

最近AIや機械学習が自分の周りでも聞かれるようになってきたので、基本だけ抑えておこうと思い読んでみました。 読んでみてなんとなくニューラルネットワークがわかった気になれるいい本だったと思います。

文字認識もいいんですが、画像認識はどうなのかと思い、 CIFAR-10 and CIFAR-100 datasets で方式はそのままで試してみました。 データの読み取り方法は、サンプルにもありますが、

def unpickle(self, file):
    import _pickle
    fo = open(file, 'rb')
    dict = _pickle.load(fo, encoding='latin1')
    fo.close()
    return dict

こんな感じで、dict型で取得できます。この中に、dataとlabelsがあり、dataは32323(RGB)の値が入っているようなので、まずはこの1次元配列をそのまま入力層として利用してみました。 本のサンプルから変更しているところはこんな感じ

    epochs = 5

    for e in range(1, epochs):
        data_dict = n.unpickle('cifar-10-batches-py/data_batch_' + str(e))
        # go through all records in the training data set
        for record in range(len(data_dict['labels'])):
            # split the record by the ',' commas
            #all_values = record.split(',')
            # scale and shift the inputs
            inputs = (numpy.asfarray(data_dict['data'][record]) / 255.0 * 0.99) + 0.01
            # create the target output values (all 0.01, except the desired label which is 0.99)
            targets = numpy.zeros(output_nodes) + 0.01
            # all_values[0] is the target label for this record
            targets[int(data_dict['labels'][record])] = 0.99
            n.train(inputs, targets)
            pass
        pass
    for record in range(len(test_dict['labels'])):
    # split the record by the ',' commas
    #all_values = record.split(',')
    # correct answer is first value
    correct_label = int(test_dict['labels'][record])
    # scale and shift the inputs
    inputs = (numpy.asfarray(test_dict['data'][record]) / 255.0 * 0.99) + 0.01
    # query the network
    outputs = n.query(inputs)
    #print(outputs)
    # the index of the highest value corresponds to the label
    label = numpy.argmax(outputs)
    # append correct or incorrect to list
    if (label == correct_label):
    # network's answer matches correct answer, add 1 to scorecard
    scorecard.append(1)
    print (label, " == ", correct_label)
    else:
    # network's answer doesn't match correct answer, add 0 to scorecard
    scorecard.append(0)
    print (label, " != ", correct_label)
    pass

    pass

これでやってみると、正解率はだいたい20%ほど。 こんなもんかという感じですが、色をまとめて入れるのはどうかという気もするのでもう少し工夫したり、CNNをやってみてもいいかも。 なかなか楽しいですが、これを自分のサービスに入れ込もうとするにはいろいろやってみて応用を効かせられるようにならないといけないですね。