gll89's personal blog http://blog.sciencenet.cn/u/gll89


Pytorch--__init__, __call__ and forward functions in a class

2017-11-9 07:25 | Personal category: Python | System category: Research notes

The following is a code segment:

=================================

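import string

import torch
from torch import nn

# Assumed setup, following the PyTorch char-RNN name-classification tutorial
# (all_letters, n_letters, n_categories and lineToTensor are not defined in the original snippet)
all_letters = string.ascii_letters + " .,;'"
n_letters = len(all_letters)
n_categories = 18  # assumption: number of language categories in the tutorial's dataset

def lineToTensor(line):
    # one-hot encode a name as a <line_length x 1 x n_letters> tensor
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][all_letters.find(letter)] = 1
    return tensor

class RNN(nn.Module):  # RNN extends (inherits from) nn.Module

    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size

        self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
        self.h2o = nn.Linear(input_size + hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)  # normalize over the category dimension

    def forward(self, input, hidden):  # defines the computation performed at every call
        combined = torch.cat((input, hidden), 1)
        hidden = self.i2h(combined)
        output = self.h2o(combined)
        output = self.softmax(output)
        return output, hidden

n_hidden = 128
input = lineToTensor('Albert')     # Variable wrappers are no longer needed since PyTorch 0.4
hidden = torch.zeros(1, n_hidden)
rnn = RNN(n_letters, n_hidden, n_categories)
out, next_hidden = rnn(input[0], hidden)  # one step on the first letter; goes through __call__ to forward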

===============================

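In the last line, we call the instance directly, rnn(...), instead of rnn.forward(...), and it still works. This is because class RNN extends nn.Module, and nn.Module defines __call__(), which in turn calls forward() and also runs any registered hooks (http://pytorch.org/docs/master/_modules/torch/nn/modules/module.html#Module). In Python, calling an instance, x(arguments), is shorthand for x.__call__(arguments), so rnn(...) ends up executing forward() with the same arguments. Note that it is the instance rnn that is called, not the class name RNN: calling the class, RNN(...), creates a new object and runs __init__(). The PyTorch documentation recommends calling the instance rather than forward() directly, because calling forward() silently skips the registered hooks.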
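To make the dispatch concrete, here is a minimal pure-Python sketch of the pattern. The Module class below is a hypothetical, heavily simplified stand-in for nn.Module (it is not PyTorch's actual implementation); only the idea that __call__ runs forward() and then the hooks is borrowed. Doubler is likewise a made-up example subclass.

```python
class Module:
    """Toy stand-in for nn.Module (hypothetical, heavily simplified)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # store a callable hook(module, inputs, output)
        self._forward_hooks.append(hook)

    def __call__(self, *args):
        # instance(...) is shorthand for instance.__call__(...)
        result = self.forward(*args)
        for hook in self._forward_hooks:
            hook(self, args, result)
        return result

    def forward(self, *args):
        # subclasses override this, just like RNN.forward above
        raise NotImplementedError


class Doubler(Module):
    def forward(self, x):
        return 2 * x


calls = []
m = Doubler()
m.register_forward_hook(lambda mod, inputs, out: calls.append(out))

print(m(21))          # goes through __call__, so the hook runs
print(m.forward(21))  # calls forward directly, so the hook is skipped
print(calls)          # only the m(21) call was recorded
```

Both calls compute the same value, but only the call through the instance triggers the hook, which is why calling rnn(...) is preferred over rnn.forward(...).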


===============================

  • Classic Classes: Class objects are described below. When a class object is called, a new class instance (also described below) is created and returned. This implies a call to the class's __init__() method if it has one. Any arguments are passed on to the __init__() method. If there is no __init__() method, the class must be called without arguments.

  • Class instances: Class instances are described below. Class instances are callable only when the class has a __call__() method; x(arguments) is a shorthand for x.__call__(arguments).

From: https://docs.python.org/2/reference/datamodel.html
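The quoted rules come from the Python 2 documentation on classic classes; in Python 3 every class behaves this way. The short sketch below illustrates both rules with two hypothetical classes, Greeter and NoCall: calling the class creates an instance and runs __init__(), while calling the instance runs __call__(), which only works if the class defines it.

```python
class Greeter:
    def __init__(self, name):
        # runs when the class object is called: Greeter("Albert")
        self.name = name

    def __call__(self, greeting):
        # runs when an instance is called: g("Hello")
        return f"{greeting}, {self.name}"


class NoCall:
    pass


g = Greeter("Albert")                      # class call: new instance, __init__ invoked
print(g("Hello"))                          # instance call: same as g.__call__("Hello")
print(g.__call__("Hello") == g("Hello"))   # the two spellings are equivalent

try:
    NoCall()()                             # instance of a class without __call__
except TypeError as exc:
    print("TypeError:", exc)
```

This is exactly the mechanism behind rnn(...) in the PyTorch example: RNN inherits a __call__() method from nn.Module, so its instances are callable.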



https://blog.sciencenet.cn/blog-1969089-1084403.html
