Potential bug of RNN training flow #2455

Open
Bonjour-Lemonde opened this issue Jun 8, 2024 · 0 comments

I ran into a strange problem using Flux RNNs. My training data consists of myX (one-hot vectors) and myY (a number). The data shown below trains very well with a feedforward network (epochs=20, R²=0.9) but very poorly with a Flux RNN (epochs=200, R²=0.2). Moreover, I am sure it is not the model architecture, because the same RNN trains well (R²=1) on other data, referred to below as otherX and otherY.
I also found that the problem is entirely in my X: the RNN also works well on [otherX, myY] and [otherX, otherY], but not on [myX, myY] or [myX, otherY]. So I suspect this is associated with a bug in the RNN training flow.
The code is below. I hope someone can help, thanks!

# julia version 1.10.3
using Flux
using Random, Statistics  # Random for shuffle (mini-batching), Statistics for mean (R² helper)
oriX=["ATAGGAGGCGCGTGACAGAGTCCCTGTCCAATTACCTACCCAAA", "ATAGGAGGCGCAAGAGAGAAGCCCAGACCAATAACCTACCCAAA", "ATAGGAGGCTAACGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGAGGCGCCTGAGAAAAGCCCAGACCAATTACCTACCCAAA", "ATAGGACGCGCATGAGAGATGCCCTGACCAATTACCTACCCAAA", "ATAGGTGGTGCATGAGATAAGCACAGCTCAATACCCTACCCAAA", "ATAGGAGACGCAGGGGCGAAGCCCGGACCATTTACCTACCCAAA", "ATAGGTGGTGCATGAGATAATCCCAGACCAATTACCTACCCAAA", "ATAGGAGGCTCATGAGATAAGGCTTGACCAATTACCTACCCAAA", "ATAGGAGGCTCATGAGAGCAGCCCAGATTAATTACCTACCCAAA", "ATAGGAGGCGCGTGAGAGAGGACCCGACCAATTACTCACCCAAC", "ATAGGCAGCGCATGAGAGAAGCCCAGACCAATTACCTACTCAAC", "ATAGGAGGCTAACGAGAGAAGCCCAGACCACTTACCTACCCAAA", "ATAGGAGGCGCATGAGAAAAGCCCCGCCCAATTACCTACCCAAG", "ATAGGCGGCGCTTGAGAGAAGCCCATACCCATTACCTACCCAAA", "ATAGGCGGCACATGAGACAAGCCGAAGCCAATTACCTACCCAAA", "ATAGGCTGCGCATGAGAGAAGGCGACACAAATTACCTACCCAAA", "ATAGGCGGCACATGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGTGCAAGAGAGACGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCGTATGAGAGAAGCCCAGCTCAATTACCTACCCAAA", "ATAGGAGGCGCATGAGATAACCCACCACCAAGTACCTACTCAAA", "ATAGGTGGCGCATGAGAGCACCTCAGACGAAGTACCTACCCAAA", "ATAGGCGGCGCATGAGATAAGCCTAGACCATTTACCTACCCAAA", "ATAGGTGGCGCATGAGATAAGCGCATAACACCAACTTACCCAAC", "ATAGGCGGCGCATGAGACAAATCCAGGCCAATTATCTACCCAAA", "ATAGGCGGCTCATGAGATAAGCCCAGACCAAATACCTACCCAAA", "ATAGGAGGCGCATGAGAGAATCCCAAACCAATTCCCTACAAACC", "ATAGGCGGCGCATGAGACAAGCCCATACCAATTACCTACCCAAA", "ATAGGTGCGACTTGAGAGATGCCCATATCGACTACCTACCCGAA", "ATAGGCGGTGCATGACTGACGCCCAGACCAATTACCTACCCAAA", "ATAGGGGGCTAATGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGAGGCGCATGAGATAAGCCCAGACCAATTACCTACCCCGA", "ATAGGAGGTGCACGAGAGTTGCCCAGACCAATTAACTTCCCAAA", "ATAGGCGGCGCATGAGAAAAGCCCAGACCAATTACCTACCCAAA", "ATAGGTGGCCCGCGAGTTAGGACGAGACTAATTCCCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCCATTACCTACCCAAA", "ATAGGCGGCGGACGAGAGAAGCCCAGACCAATTACCTACCCATA", "ATAGGTGGCGCATGAAATAAAACCAGTGCAATTACCTACCCATA", "ATAGGCGACGCATGAGAAAAGCCCAGACCCATTACCTACCCAAA", "ATAGGCTGCGCATGAGAGAAGCCCAGACCAATTATCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAATTCCCTACCCAAA", "ATAGGAGTCGCCTGACAGATGACCATACCAATTACCTATCCAAA", "ATAGGCCGCGGATTAGACAACATCTTACCAATTCCCTGCCCAAA", "ATAGGCGGTGCAAGAGCGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGAAGCTAAAGGGAGTAGCTCAGTACAGTTAACTACCCCAA", "ATAGGCCGCGCATGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAGTTACCTACCCAAA", "ATAGGAAGCGCATGAGAAAAGCCCAGACAAATCACCTACCGAAC", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAGTTACCTACCCAAC", "ATAGGCGGCACATGAGCGCAGCCCAGTCCAATTACCTACCCAAA", "ATAGGCGGCGCATGACACAGGCCCAGACCAATGACCTACCCAAA", "ATAGGCAGCGCATGAGAGAAGCCCAGACCAATTACCTACTCAAA", "ATAGGCGACGAATGAGTGAAGCCCACATTAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAACTACATACCCAAA", "ATAGGCGGCGCATGAGACAAATCCAGGCCAATTACCTACCCAAA", "ATAGGCCGCCGATGAGAAAAGCCCGACGCACTTAACTACCCGAA", "ATAGGCGGTGCATGAGAGAGACGCAGTGCAAATACCTACCCAAA", "ATAGGCGGCGGATTAGAGAAGTCCAGACTATTTACCTACCCAAA", "ATAGGCGGCGAATGAGAGAAGCCCAGACCAATTACCTACCCAGA", "ATAGGCGGCGCATGAGATAAGCCCAGTCGAATTACCTACCCAAA", "ATAGGCCGCGCATGAGAAAAGCCTAGACCAATTGCCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAATTACCTACCCACA", "ATAGGCGACGCATGAGAGAAGCCCAGACGAATTACCTACCCAAA", "ATAGGTCCAGCATTAAGGCAGGCCAGACCCTTTACCTACCCAAA", "ATAGGAGGGACATGCGATAGGCTCAGACCAATTTCCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGGCCAATTAACTACCCAAA", "ATAGGCGGCGCATGAGAGTAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGGGAAGCCCAGACCCATTCCCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAATTAGCTACCCAAA", "ATAGGCGACGTATGAGAGAATCCCTGACCATTTACCTACCCAAA", "ATAGGCGGCGCATGATATAAGCCCAGCCCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACATATTACCTACCCAAA", 
"ATAGGCGGCGCATGAGAGAAGCTAAGACCCATTACCTACCCAAA", "ATAGGCCGCGCATGAGAGAAGCTCAGACCCATTACCTACCCAAA", "ATAGGTGGCGCATGAGAGAAGCCCAGACCAATTACCTACACAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGACCAATTACCTGCCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGTCCGATTTACTACCCAAG", "ATAGGCGGAGCATATGAGATGCCCAGACCAAATACCTACCCAAA", "ATAGGCGGCGCATGACAGAAGCCCTGACCGATAACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCCCAGAGCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAAGCTCAGACCAATTACCTACCCAAA", "ATAGGTGTCGCTTGAAAATAGCCCAGACGAATTACCTACCCAAA", "ATAGGCGGCGCATGAGCGTTGCACAGACCAATTACCTACCCAAA", "ATAGGCGGCGTATGAGAGAAGCGCGGCCCAATTACCTACCCAAA", "ATAGGCGGCGCATGAGAGAGGCCCTGACCAAATAACTACCCAAA", "ATAGGCGGCTCATGAGAGAAGCCCAGACCAACTGCCTACCCAAA", "ATAGGCAGCGCATGAGTGAAGCCCAGACCAGTTACCTCCCCAAA", "ATAGGCAGCAGATGACAGTAGCCCCGACCAAATTACTACTCAAA", "ATAGGCGGCGCATGAGAGGAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCCCAGGAGAGCATCCAAGACCAATTACCTACCCAAA", "ATAGGCGGCGCATACGAGAAGCGCAGACTAATTACCTACCCAAA", "ATAGGCGGCGCATGACATAAGCCTAGATCAATTACCTACCCAAG", "ATAGGCGGCACATGACACAGGCCCAGACCAATGACCTACCCAAA", "ATAGGCGGCGCAGGAGAGAAGCCCAGACCAATTACCTACCCAAA", "ATAGGCGGCGCGTAAGAGAAGCCTAGACAAATTACCTACCGAAA", "ATAGGCGACGCATCTGCGAATCCCACACCAATTACCTACCCGAA", "ATAGGCGACGCATGAGAGAAGCCCAGACCAATTAACTATCCATC", "ATAGGCGGCGCATGAGAGCAGCCCAAACCAATTGCCTACCCAAA", "ATAGGCGGCGCGTGAGAGTAGCCCTGACCAGTTTCCTGCCCAAA"]

ytrain = Float32[3.5539734 2.7561886 2.8113236 2.7176633 2.7606876 2.6220577 2.3115172 2.4817004 1.9276409 2.5030751 1.8989105 1.6381569 3.112245 1.9992489 1.8364053 2.1545537 1.8151137 1.9761252 2.0710406 1.8238684 1.5769696 2.2978039 2.0652819 1.6795048 1.4621212 1.8550924 1.3247801 2.0052798 1.5950761 2.1166725 1.1718857 1.443101 1.4597932 2.0249891 1.659723 1.7782362 1.3042092 1.3574703 1.7164876 1.4561561 1.6886593 1.5327756 1.3272716 1.2478243 1.6909612 0.9371975 1.3504946 1.7342895 1.0429348 1.6653012 1.6186994 1.6343817 1.1894267 1.6500783 1.1910686 1.5190029 0.93479043 1.5677443 1.2633525 1.4441946 1.8120437 1.6296253 1.3869075 1.7520566 1.247555 1.4638474 1.4413416 1.5457458 1.3801547 1.312296 0.96203357 1.571632 0.2540248 1.0096036 0.8302187 0.73939687 1.4816427 1.1275434 1.1184824 1.3548776 1.3924822 1.2923665 0.9824461 1.2085876 1.3007151 1.4721189 1.3741052 0.7266495 0.5496262 1.3403294 0.931344 0.7101498 1.3628994 1.8999943 1.2633573 1.1379782 0.6508444 0.5403087 1.435614 1.319527]

Xtrain = [map(x -> Flux.onehot.(x, "ACGT"), collect(join(oriX[idx]))) for idx in 1:100]
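
# NOTE: `seqlen` is used in the RNN training loop below but is never defined in the
# snippet as posted; presumably it is the common sequence length:
seqlen = length(first(oriX))   # 44 for the sequences above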


Xtrain_ffnn = hcat([vcat(x...) for x ∈ Xtrain]...)
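
# NOTE: R² is called in `accuracy` below but is not defined in the snippet as posted.
# A minimal coefficient-of-determination helper (my assumption of what was intended):
R²(y, ŷ) = 1 - sum(abs2, y .- ŷ) / sum(abs2, y .- mean(y))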

# loss function and accuracy (R²)
function accuracy(m, X, y)
    Flux.reset!(m) # Only important for recurrent network
    R²(y, m(X))
end

function lossFun(m, X, y)
    Flux.reset!(m) # Only important for recurrent network
    Flux.mse(m(X),y)
end

# first, learn the training data with a feedforward network
ffnn = Chain(
    Dense(176 => 128, relu),   # 176 = 44 positions × 4 bases (flattened one-hot input)
    Dense(128 => 128, relu),
    Dense(128 => 1)
)
opt_ffnn = ADAM()
θ_ffnn = Flux.params(ffnn) # Keep track of the trainable parameters
epochs = 100 # Train the model for 100 epochs
for epoch ∈ 1:epochs
    # Train the model using batches of size 32
    for idx ∈ Iterators.partition(shuffle(1:size(Xtrain_ffnn, 2)), 32)
        X, y = Xtrain_ffnn[:, idx], ytrain[:, idx]
        ∇ = gradient(θ_ffnn) do 
            # Flux.logitcrossentropy(ffnn(X), y)
            Flux.mse(ffnn(X),y)
        end
        Flux.update!(opt_ffnn, θ_ffnn, ∇)
    end
    @show accuracy(ffnn, Xtrain_ffnn, ytrain)
end

# then learn the training data with a sequence-to-one RNN

struct Seq2One
    rnn # Recurrent layers
    fc  # Fully-connected layers
end
Flux.@functor Seq2One # Make the structure differentiable
# Define behavior of passing data to an instance of this struct
function (m::Seq2One)(X)
    # Run recurrent layers on all but final data point
    [m.rnn(x) for x ∈ X[1:end-1]]
    # Pass last data point through both recurrent and fully-connected layers
    m.fc(m.rnn(X[end])) 
end
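
# Note on input format: the forward pass above expects X to be a Vector of length
# seqlen, where each element is a 4×batchsize matrix holding the one-hot columns for
# one sequence position across the whole batch, e.g. (same reshaping as in the loop below):
#   Xfull = [hcat([x[i] for x ∈ Xtrain]...) for i ∈ 1:seqlen]
#   seq2one(Xfull)   # 1×100 predictions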

# Create the sequence-to-one network using a similar layer architecture as above
seq2one = Seq2One(
    Chain(
        RNN(4 => 128, relu),
        RNN(128 => 128, relu)
    ),
    Dense(128 => 1)
)
opt_rnn = ADAM()
θ_rnn = Flux.params(seq2one) # Keep track of the trainable parameters
epochs = 200 # Train the model for 200 epochs
for epoch ∈ 1:epochs
    # Train the model using batches of size 32
    for idx ∈ Iterators.partition(shuffle(1:size(Xtrain, 1)), 32)
        Flux.reset!(seq2one) # Reset hidden state
        X, y = Xtrain[idx], ytrain[:, idx]
        X = [hcat([x[i] for x ∈ X]...) for i ∈ 1:seqlen] # Reshape X for RNN format
        ∇ = gradient(θ_rnn) do 
            # Flux.logitcrossentropy(seq2one(X), y)
            Flux.mse(seq2one(X),y)
        end
        Flux.update!(opt_rnn, θ_rnn, ∇)
    end
    X, y = [hcat([x[i] for x ∈ Xtrain]...) for i ∈ 1:seqlen], ytrain
    @show accuracy(seq2one, X, y)
end
mcabbott added the RNN label on Jun 15, 2024