Source code for ecopann.fcnet

# -*- coding: utf-8 -*-

from . import sequence as seq
from . import nodeframe
import torch
import torch.nn as nn
import numpy as np


#%% fully connected single network
[docs]class FcNet(torch.nn.Module): """Get a fully connected network. Parameters ---------- node_in : int The number of the input nodes. node_out : int The number of the output nodes. hidden_layer : int The number of the hidden layers. nodes : None or list If list, it should be a collection of nodes of the network, e.g. [node_in, node_hidden1, node_hidden2, ..., node_out] activation_func : str Activation function. """ def __init__(self, node_in=2000, node_out=6, hidden_layer=3, nodes=None, activation_func='rrelu'): super(FcNet, self).__init__() if nodes is None: nodes = nodeframe.decreasingNode(node_in=node_in, node_out=node_out, hidden_layer=hidden_layer, get_allNode=True) self.fc = seq.LinearSeq(nodes, mainActive=activation_func, finalActive='None', mainBN=True, finalBN=False, mainDropout='None', finalDropout='None').get_seq()
[docs] def forward(self, x): x = self.fc(x) return x
#%% multibranch network
[docs]def split_nodes(nodes, weight=[]): nodes_new = [[] for i in range(len(weight))] for i in range(len(weight)): for j in range(len(nodes)): nodes_new[i].append(round(nodes[j]*weight[i])) return nodes_new
[docs]class MultiBranchFcNet(nn.Module): """Get a multibranch network. Parameters ---------- nodes_in : list The number of the input nodes for each branch. e.g. [node_in_branch1, node_in_branch2, ...] node_out : int The number of the output nodes. branch_hiddenLayer : int The number of the hidden layers for the branch part. trunk_hiddenLayer : int The number of the hidden layers for the trunk part. nodes_all : list The number of nodes of the multibranch network. e.g. [nodes_branch1, nodes_branch2, ..., nodes_trunk] """ def __init__(self, nodes_in=[100,100,20], node_out=6, branch_hiddenLayer=1, trunk_hiddenLayer=3, nodes_all=None, activation_func='rrelu'): super(MultiBranchFcNet, self).__init__() if nodes_all is None: # method 1 nodes_all = [] branch_outs = [] fc_hidden = branch_hiddenLayer*2 + 1 # fc_hidden = branch_hiddenLayer + trunk_hiddenLayer + 1 #also works, but not necessary for i in range(len(nodes_in)): fc_node = nodeframe.decreasingNode(node_in=nodes_in[i], node_out=node_out, hidden_layer=fc_hidden, get_allNode=True) branch_node = fc_node[:branch_hiddenLayer+2] nodes_all.append(branch_node) branch_outs.append(branch_node[-1]) nodes_all.append(nodeframe.decreasingNode(node_in=sum(branch_outs), node_out=node_out, hidden_layer=trunk_hiddenLayer, get_allNode=True)) # #method 2 # nodes_all = [] # branch_outs = [] # fc_hidden = branch_hiddenLayer + trunk_hiddenLayer + 1 # fc_hidd_node = nodeframe.decreasingNode(node_in=sum(nodes_in), node_out=node_out, hidden_layer=fc_hidden, get_allNode=False) # fc_hidd_node_split = split_nodes(fc_hidd_node[:branch_hiddenLayer+1], weight=[nodes_in[i]/sum(nodes_in) for i in range(len(nodes_in))]) # for i in range(len(nodes_in)): # branch_node = [nodes_in[i]] + fc_hidd_node_split[i] # nodes_all.append(branch_node) # branch_outs.append(branch_node[-1]) # trunk_node = [sum(branch_outs)] + list(fc_hidd_node[branch_hiddenLayer+1:]) + [node_out] # nodes_all.append(trunk_node) # #method 3 # nodes_all = [] # nodes_comb = [] # fc_hidden = branch_hiddenLayer + trunk_hiddenLayer + 1 # for i in range(len(nodes_in)): # fc_node = nodeframe.decreasingNode(node_in=nodes_in[i], node_out=node_out, hidden_layer=fc_hidden, get_allNode=True) # branch_node = fc_node[:branch_hiddenLayer+2] # nodes_all.append(branch_node) # nodes_comb.append(fc_node[branch_hiddenLayer+1:-1]) # trunk_node = list(np.sum(np.array(nodes_comb), axis=0)) + [node_out] # nodes_all.append(trunk_node) self.branch_n = len(nodes_all) - 1 for i in range(self.branch_n): exec("self.branch%s = seq.LinearSeq(nodes_all[i],mainActive=activation_func,finalActive=activation_func,mainBN=True,finalBN=True,mainDropout='None',finalDropout='None').get_seq()"%(i+1)) self.trunk = seq.LinearSeq(nodes_all[-1],mainActive=activation_func,finalActive='None',mainBN=True,finalBN=False,mainDropout='None',finalDropout='None').get_seq()
[docs] def forward(self, x_all): x1 = self.branch1(x_all[0]) x_comb = x1 for i in range(1, self.branch_n-1+1): x_n = eval('self.branch%s(x_all[i])'%(i+1))#Note:i & i+1 x_comb = torch.cat((x_comb, x_n),1) x = self.trunk(x_comb) return x