Implementing the parameter optimization algorithm Adam [Python]


Adam Class Definition

Adam's weight update rule is as follows:
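The formulas below restate the standard Adam update (as introduced by Kingma & Ba), which is exactly what the code in this article implements; α is the learning rate, g_t the gradient at step t, and ε a small constant to avoid division by zero (1e-7 in the code):

```latex
m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2
\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad
\hat{v}_t = \frac{v_t}{1 - \beta_2^t}
\theta_t = \theta_{t-1} - \frac{\alpha\, \hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
```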

Implementing this directly gives the following code.

Define the Adam class:

# Adam
import numpy as np

class Adam:
    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999):
        self.lr = lr          # learning rate
        self.beta1 = beta1    # decay rate for the first-moment estimate
        self.beta2 = beta2    # decay rate for the second-moment estimate
        self.iter = 0
        self.m = None         # first-moment estimates
        self.v = None         # second-moment estimates

    def update(self, params, grads):
        # Initialize the moment estimates on the first call
        if self.m is None:
            self.m, self.v = {}, {}
            for key, val in params.items():
                self.m[key] = np.zeros_like(val)
                self.v[key] = np.zeros_like(val)

        self.iter += 1

        for key in params.keys():
            # Exponential moving averages of the gradient and its square
            self.m[key] = self.beta1 * self.m[key] + (1 - self.beta1) * grads[key]
            self.v[key] = self.beta2 * self.v[key] + (1 - self.beta2) * (grads[key]**2)
            # Bias-corrected estimates
            m_unbias = self.m[key] / (1 - self.beta1**self.iter)
            v_unbias = self.v[key] / (1 - self.beta2**self.iter)
            # Parameter update (1e-7 avoids division by zero)
            params[key] -= self.lr * m_unbias / (np.sqrt(v_unbias) + 1e-7)
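As a quick sanity check, the same update rule can be applied by hand to a toy one-dimensional problem, minimizing f(w) = (w − 3)²; the learning rate 0.1 and 1000 iterations here are illustrative choices, not values from the article:

```python
# Adam update rule applied to a toy problem: minimize f(w) = (w - 3)^2
import numpy as np

lr, beta1, beta2 = 0.1, 0.9, 0.999
params = {'w': np.array([0.0])}
m = {'w': np.zeros(1)}  # first-moment estimate
v = {'w': np.zeros(1)}  # second-moment estimate

for t in range(1, 1001):
    grads = {'w': 2 * (params['w'] - 3.0)}  # gradient of (w - 3)^2
    for key in params:
        m[key] = beta1 * m[key] + (1 - beta1) * grads[key]
        v[key] = beta2 * v[key] + (1 - beta2) * grads[key]**2
        m_unbias = m[key] / (1 - beta1**t)   # bias correction
        v_unbias = v[key] / (1 - beta2**t)
        params[key] -= lr * m_unbias / (np.sqrt(v_unbias) + 1e-7)

print(params['w'])  # approaches the minimum at w = 3
```

Because the bias-corrected ratio m̂/√v̂ starts out near ±1, the early steps move by roughly the learning rate regardless of gradient scale, which is one of Adam's practical advantages over plain SGD.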

Using Adam in the train function

Incorporating Adam into the train function used for model training gives the following code.

def train(x, t, eps=0.005):
    global W, b  # weights and biases
    batch_size = x.shape[0]

    # Forward pass
    t_hat = softmax(np.matmul(x, W) + b)

    # Cross-entropy loss (np_log is a log helper defined elsewhere)
    cost = (- t * np_log(t_hat)).sum(axis=1).mean()
    delta = t_hat - t

    # Gradients of the loss with respect to W and b
    dW = np.matmul(x.T, delta) / batch_size
    db = np.matmul(np.ones(shape=(batch_size,)), delta) / batch_size

    # Update the parameters with Adam
    params = {'W': W, 'b': b}
    grads = {'W': dW, 'b': db}
    adam.update(params, grads)
    return cost
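For completeness, here is a self-contained sketch of how this train function might be driven end to end. The softmax and np_log helpers, the toy two-class data, and the hyperparameters are all illustrative assumptions (np_log is assumed to be a log clipped away from zero), not definitions from the original article:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def np_log(x):
    # Assumed helper: log clipped away from zero to avoid log(0)
    return np.log(np.clip(x, 1e-10, 1.0))

class Adam:
    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999):
        self.lr, self.beta1, self.beta2 = lr, beta1, beta2
        self.iter, self.m, self.v = 0, None, None

    def update(self, params, grads):
        if self.m is None:
            self.m = {k: np.zeros_like(v) for k, v in params.items()}
            self.v = {k: np.zeros_like(v) for k, v in params.items()}
        self.iter += 1
        for key in params:
            self.m[key] = self.beta1 * self.m[key] + (1 - self.beta1) * grads[key]
            self.v[key] = self.beta2 * self.v[key] + (1 - self.beta2) * grads[key]**2
            m_unbias = self.m[key] / (1 - self.beta1**self.iter)
            v_unbias = self.v[key] / (1 - self.beta2**self.iter)
            params[key] -= self.lr * m_unbias / (np.sqrt(v_unbias) + 1e-7)

# Toy 2-class data: the class is decided by the sign of the first feature
rng = np.random.default_rng(0)
x_all = rng.normal(size=(200, 2))
labels = (x_all[:, 0] > 0).astype(int)
t_all = np.eye(2)[labels]  # one-hot targets

W = np.zeros((2, 2))
b = np.zeros(2)
adam = Adam(lr=0.01)

def train(x, t):
    global W, b
    batch_size = x.shape[0]
    t_hat = softmax(np.matmul(x, W) + b)
    cost = (- t * np_log(t_hat)).sum(axis=1).mean()
    delta = t_hat - t
    dW = np.matmul(x.T, delta) / batch_size
    db = np.matmul(np.ones(shape=(batch_size,)), delta) / batch_size
    adam.update({'W': W, 'b': b}, {'W': dW, 'b': db})
    return cost

costs = [train(x_all, t_all) for _ in range(300)]
print(costs[0], costs[-1])  # the loss should decrease over training
```

Note that Adam mutates W and b in place (`params[key] -= ...` is an in-place NumPy operation on the arrays stored in the dict), which is why passing them through the params dict is enough to update the globals.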

The basic model-training program is described in detail in the article below:

Reference
