How can I recreate this PyTorch layer/network in Keras?

I'm currently working on a project, part of which is reimplementing in Keras a model that was written in PyTorch for a paper. The overall model classifies proteins based on three kinds of information: their sequence, their interactions with other proteins, and the domains (motifs) in their sequence. The part I'm currently recreating is the protein-protein interaction part. First, the input vectors simply go through some fully connected layers, which is easy enough to implement in Keras. However, the outputs of this model are fed into a 'weight classifier' model whose first layer, MaskedLinear, was written specifically for this model: it subclasses nn.Linear, multiplies the layer's weight matrix element-wise by a binary mask matrix loaded from a CSV file, and then applies the result with F.linear from PyTorch's nn.functional API.
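To make the masking concrete, here is a small NumPy sketch (not taken from the paper's code) of what MaskedLinear computes: an ordinary linear map whose weight matrix is first multiplied element-wise by the 0/1 mask, so each output unit only sees the inputs that its mask row allows. The function name masked_linear, the toy sizes and the mask values are made up for illustration; weight and mask use PyTorch's (out_features, in_features) layout.

```python
import numpy as np


def masked_linear(x, weight, mask, bias):
    """NumPy equivalent of F.linear(x, weight * mask, bias).

    weight and mask follow PyTorch's nn.Linear layout: (out_features, in_features).
    """
    return x @ (weight * mask).T + bias


# Hypothetical toy sizes: 2 output units and 2 * 3 = 6 inputs,
# mirroring OUT_nodes(func) and OUT_nodes(func) * 3 in the PyTorch code.
rng = np.random.default_rng(0)
weight = rng.normal(size=(2, 6))
bias = rng.normal(size=2)
mask = np.array([[1, 0, 0, 1, 0, 0],
                 [0, 1, 0, 0, 1, 1]], dtype=float)

x = rng.normal(size=(1, 6))
y = masked_linear(x, weight, mask, bias)

# Perturb only the inputs that row 0 of the mask switches off:
# output unit 0 is unaffected, because those weights are multiplied by 0.
x_changed = x.copy()
x_changed[0, [1, 2, 4, 5]] += 10.0
y_changed = masked_linear(x_changed, weight, mask, bias)
```

Since the masked weights are multiplied by zero in the forward pass, their gradients are zero as well, so they never influence training; any framework that can express weight * mask inside a layer reproduces this behavior.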

Here is the code I am struggling to implement in Keras:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Weight_classifier(nn.Module):
    def __init__(self, func):
        super(Weight_classifier, self).__init__()
        # self.weight_layer = nn.Linear(OUT_nodes(func)*3, OUT_nodes(func))
        self.weight_layer = MaskedLinear(OUT_nodes(func)*3, OUT_nodes(func), 'data/{}_maskmatrix.csv'.format(func)).cuda()
        self.outlayer= nn.Linear(OUT_nodes(func), OUT_nodes(func))

    def forward(self, weight_features):
        weight_out = self.weight_layer(weight_features)
        # weight_out = F.sigmoid(weight_out)
        weight_out = F.relu(weight_out)
        weight_out = torch.sigmoid(self.outlayer(weight_out))
        return weight_out


class MaskedLinear(nn.Linear):
    def __init__(self, in_features, out_features, relation_file, bias=True):
        super(MaskedLinear, self).__init__(in_features, out_features, bias)

        mask = self.readRelationFromFile(relation_file)
        self.register_buffer('mask', mask)
        self.iter = 0

    def forward(self, input):
        masked_weight = self.weight * self.mask
        return F.linear(input, masked_weight, self.bias)

    def readRelationFromFile(self, relation_file):
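        mask = []  # a list, since tuples have no append()
        with open(relation_file, 'r') as f:
            for line in f:
                # a list, not a generator: a generator would be exhausted by the check below
                l = [int(x) for x in line.strip().split(',')]
                for item in l:
                    assert item == 1 or item == 0  # relation can only be 0 or 1
                mask.append(l)
        return torch.tensor(mask, dtype=torch.float32)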
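A possible Keras translation, offered only as a sketch: the names read_relation_file, MaskedDense and build_weight_classifier are hypothetical, it assumes TensorFlow 2.x with tf.keras, and it assumes the CSV keeps the PyTorch layout of one row per output unit (out_features x in_features). Because a Keras kernel is stored as (in_features, out_features), the mask is transposed once when the layer is created, and it is kept as a constant tensor rather than a trainable weight, playing the role of register_buffer.

```python
import numpy as np
import tensorflow as tf


def read_relation_file(path):
    """Load the 0/1 relation mask from a CSV file (one row per output unit)."""
    mask = np.loadtxt(path, delimiter=",", dtype=np.float32, ndmin=2)
    assert np.isin(mask, (0.0, 1.0)).all(), "relation can only be 0 or 1"
    return mask


class MaskedDense(tf.keras.layers.Layer):
    """Dense layer whose kernel is multiplied element-wise by a fixed 0/1 mask."""

    def __init__(self, units, mask, use_bias=True, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.use_bias = use_bias
        # The mask comes in PyTorch's (out, in) layout; a Keras kernel is (in, out),
        # so transpose it once. As a constant it is never trained (like register_buffer).
        self.kernel_mask = tf.constant(np.asarray(mask, dtype=np.float32).T)

    def build(self, input_shape):
        in_features = int(input_shape[-1])
        if tuple(self.kernel_mask.shape) != (in_features, self.units):
            raise ValueError(
                f"mask has shape {tuple(self.kernel_mask.shape)} after transposing, "
                f"expected {(in_features, self.units)}")
        self.kernel = self.add_weight(name="kernel", shape=(in_features, self.units),
                                      initializer="glorot_uniform", trainable=True)
        if self.use_bias:
            self.bias = self.add_weight(name="bias", shape=(self.units,),
                                        initializer="zeros", trainable=True)

    def call(self, inputs):
        # Masked-out weights contribute nothing and receive zero gradient.
        outputs = tf.matmul(inputs, self.kernel * self.kernel_mask)
        if self.use_bias:
            outputs = outputs + self.bias
        return outputs


def build_weight_classifier(out_nodes, mask):
    """Keras counterpart of Weight_classifier: masked layer -> ReLU -> Dense -> sigmoid."""
    inputs = tf.keras.Input(shape=(out_nodes * 3,))
    x = MaskedDense(out_nodes, mask, name="weight_layer")(inputs)
    x = tf.keras.layers.Activation("relu")(x)
    outputs = tf.keras.layers.Dense(out_nodes, activation="sigmoid", name="outlayer")(x)
    return tf.keras.Model(inputs, outputs)
```

With the helpers from the original code this would be built as build_weight_classifier(OUT_nodes(func), read_relation_file('data/{}_maskmatrix.csv'.format(func))). Two differences from the PyTorch version remain: Keras initializes the kernel with Glorot uniform and the bias with zeros, whereas nn.Linear draws both from a uniform distribution bounded by 1/sqrt(in_features); and this sketch defines no get_config, so saving the whole model (rather than just its weights) may need extra work.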

This is the paper I am working from; it contains several diagrams and explanations of the models, in case I have not explained the issue sufficiently.

Many thanks.