# -*- coding: utf-8 -*-
import torch.nn as nn
import torch.nn.functional as F
#%% activation functions
def relu():
    """Return a new ``nn.ReLU`` module that operates in place."""
    # inplace=True is used here to save GPU memory.
    return nn.ReLU(inplace=True)
def leakyrelu():
    """Return a new ``nn.LeakyReLU`` module that operates in place."""
    return nn.LeakyReLU(inplace=True)
def prelu():
    """Return a new ``nn.PReLU`` module built with its default arguments."""
    return nn.PReLU()
def rrelu():
    """Return a new ``nn.RReLU`` module that operates in place."""
    return nn.RReLU(inplace=True)
def relu6():
    """Return a new ``nn.ReLU6`` module that operates in place."""
    return nn.ReLU6(inplace=True)
def elu():
    """Return a new ``nn.ELU`` module that operates in place."""
    return nn.ELU(inplace=True)
class ELU_1(nn.Module):
    """ELU activation shifted up by one: ``F.elu(x) + 1``."""

    def forward(self, x):
        """Apply ``F.elu`` to ``x`` and add 1 to the result.

        The input is not modified; a new tensor is returned.
        """
        return F.elu(x) + 1
def elu_1():
    """Return a new ``ELU_1`` module (``F.elu(x) + 1``)."""
    return ELU_1()
def celu():
    """Return a new ``nn.CELU`` module that operates in place."""
    return nn.CELU(inplace=True)
def selu():
    """Return a new ``nn.SELU`` module that operates in place."""
    return nn.SELU(inplace=True)
def silu():
    """Return a new ``nn.SiLU`` module that operates in place."""
    return nn.SiLU(inplace=True)
def sigmoid():
    """Return a new ``nn.Sigmoid`` module."""
    return nn.Sigmoid()
def logsigmoid():
    """Return a new ``nn.LogSigmoid`` module."""
    return nn.LogSigmoid()
def tanh():
    """Return a new ``nn.Tanh`` module."""
    return nn.Tanh()
def tanhshrink():
    """Return a new ``nn.Tanhshrink`` module."""
    return nn.Tanhshrink()
def softsign():
    """Return a new ``nn.Softsign`` module."""
    return nn.Softsign()
def softplus():
    """Return a new ``nn.Softplus`` module built with its default arguments."""
    return nn.Softplus()
class Softplus_1(nn.Module):
    """Softplus activation shifted down by one: ``F.softplus(x) - 1``."""

    def forward(self, x):
        """Apply ``F.softplus`` to ``x`` and subtract 1 from the result.

        The input is not modified; a new tensor is returned.
        """
        return F.softplus(x) - 1
def softplus_1():
    """Return a new ``Softplus_1`` module (``F.softplus(x) - 1``)."""
    return Softplus_1()
class Softplus_2(nn.Module):
    """Softplus activation shifted down by two: ``F.softplus(x) - 2``."""

    def forward(self, x):
        """Apply ``F.softplus`` to ``x`` and subtract 2 from the result.

        The input is not modified; a new tensor is returned.
        """
        return F.softplus(x) - 2
def softplus_2():
    """Return a new ``Softplus_2`` module (``F.softplus(x) - 2``)."""
    return Softplus_2()
class Sigmoid_1(nn.Module):
    """Sigmoid activation shifted down by one half: ``F.sigmoid(x) - 0.5``."""

    def forward(self, x):
        """Apply ``F.sigmoid`` to ``x`` and subtract 0.5 from the result.

        The input is not modified; a new tensor is returned.
        """
        return F.sigmoid(x) - 0.5
def sigmoid_1():
    """Return a new ``Sigmoid_1`` module (``F.sigmoid(x) - 0.5``)."""
    return Sigmoid_1()
def activation(activation_name='rrelu'):
    """Build an activation module from the name of its factory function.

    Parameters
    ----------
    activation_name : str, optional
        The name of the activation function, which can be 'relu', 'leakyrelu',
        'prelu', 'rrelu', 'relu6', 'elu', 'elu_1', 'celu', 'selu', 'silu',
        'sigmoid', 'sigmoid_1', 'logsigmoid', 'tanh', 'tanhshrink', 'softsign',
        'softplus', 'softplus_1', or 'softplus_2'. Default: 'rrelu'

    Returns
    -------
    object
        The object returned by calling the module-level factory of that name
        with no arguments.

    Raises
    ------
    NameError
        If ``activation_name`` does not name a callable defined in this module.

    Note
    ----
    Although many activation functions are available, the recommended
    activation function is 'rrelu'.
    """
    # Resolve the factory by name in the module namespace rather than
    # eval()-ing the string, so an arbitrary expression passed as
    # ``activation_name`` is never executed.
    factory = None
    if isinstance(activation_name, str):
        factory = globals().get(activation_name)
    if not callable(factory):
        # NameError is what the former eval-based lookup raised for an
        # unknown name, so callers catching it keep working.
        raise NameError('unknown activation function: %r' % (activation_name,))
    return factory()
#%% Pooling
def maxPool1d(kernel_size):
    """Return a new ``nn.MaxPool1d`` module with the given ``kernel_size``."""
    return nn.MaxPool1d(kernel_size)
def maxPool2d(kernel_size):
    """Return a new ``nn.MaxPool2d`` module with the given ``kernel_size``."""
    return nn.MaxPool2d(kernel_size)
def maxPool3d(kernel_size):
    """Return a new ``nn.MaxPool3d`` module with the given ``kernel_size``."""
    return nn.MaxPool3d(kernel_size)
def avgPool1d(kernel_size):
    """Return a new ``nn.AvgPool1d`` module with the given ``kernel_size``."""
    return nn.AvgPool1d(kernel_size)
def avgPool2d(kernel_size):
    """Return a new ``nn.AvgPool2d`` module with the given ``kernel_size``."""
    return nn.AvgPool2d(kernel_size)
def avgPool3d(kernel_size):
    """Return a new ``nn.AvgPool3d`` module with the given ``kernel_size``."""
    return nn.AvgPool3d(kernel_size)
def pooling(pool_name='maxPool2d', kernel_size=2):
    """Build a pooling module from the name of its factory function.

    Parameters
    ----------
    pool_name : str, optional
        The name of the pooling factory, which can be 'maxPool1d',
        'maxPool2d', 'maxPool3d', 'avgPool1d', 'avgPool2d', or 'avgPool3d'.
        Default: 'maxPool2d'
    kernel_size : optional
        Passed unchanged to the selected factory. Default: 2

    Returns
    -------
    object
        The object returned by calling the module-level factory of that name
        with ``kernel_size``.

    Raises
    ------
    NameError
        If ``pool_name`` does not name a callable defined in this module.
    """
    # Resolve the factory by name in the module namespace rather than
    # eval()-ing the string, so an arbitrary expression passed as
    # ``pool_name`` is never executed.
    factory = None
    if isinstance(pool_name, str):
        factory = globals().get(pool_name)
    if not callable(factory):
        # NameError is what the former eval-based lookup raised for an
        # unknown name, so callers catching it keep working.
        raise NameError('unknown pooling name: %r' % (pool_name,))
    return factory(kernel_size)
#%% Dropout
def dropout():
    """Return a new ``nn.Dropout`` module that does not operate in place."""
    return nn.Dropout(inplace=False)
def dropout2d():
    """Return a new ``nn.Dropout2d`` module that does not operate in place."""
    return nn.Dropout2d(inplace=False)
def dropout3d():
    """Return a new ``nn.Dropout3d`` module that does not operate in place."""
    return nn.Dropout3d(inplace=False)
def get_dropout(drouput_name='dropout'):
    """Build a dropout module from the name of its factory function.

    Parameters
    ----------
    drouput_name : str, optional
        The name of the dropout factory, which can be 'dropout', 'dropout2d',
        or 'dropout3d'. Default: 'dropout'. The misspelled parameter name is
        kept so existing keyword callers keep working.

    Returns
    -------
    object
        The object returned by calling the module-level factory of that name
        with no arguments.

    Raises
    ------
    NameError
        If ``drouput_name`` does not name a callable defined in this module.
    """
    # Resolve the factory by name in the module namespace rather than
    # eval()-ing the string, so an arbitrary expression passed as
    # ``drouput_name`` is never executed.
    factory = None
    if isinstance(drouput_name, str):
        factory = globals().get(drouput_name)
    if not callable(factory):
        # NameError is what the former eval-based lookup raised for an
        # unknown name, so callers catching it keep working.
        raise NameError('unknown dropout name: %r' % (drouput_name,))
    return factory()