
Expose train loop to user code #1471

Open · wants to merge 2 commits into base: master

Conversation

@DhairyaLGandhi (Member) commented Jan 19, 2021

This is an example of doing something like #1461: we can now add schedulers as prehooks, and the loss etc. is available locally to users so they can write better, more efficient callbacks and manage when updates happen via prehooks. A prehook can throw SkipException to skip the parameter update entirely, or run arbitrary code in general.

Generally, it exposes the loop's inner objects to custom callbacks rather than having them picked up from global scope or through other hacks that can hurt performance, which we have seen happen often.

Of course, this only implements the changes to the loop itself; the plan to document and advertise the plain for loop more still stands. The point is to see whether this kind of API can be adopted: it seems clean enough and helps plug some holes in our training routines.
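A rough sketch, in current Flux/Zygote terms, of the loop shape being proposed here; the prehooks keyword, the function name, and the exact signatures are illustrative rather than the PR's actual diff:

using Flux, Zygote

function train_with_hooks!(loss, ps, data, opt; cb = (args...) -> nothing,
                           prehooks = (args...) -> nothing)
  for d in data
    try
      # pullback exposes the loss locally, so hooks and callbacks get it for free
      l, back = Zygote.pullback(() -> loss(d...), ps)
      gs = back(one(l))
      prehooks(l, ps, gs, d, opt)          # may throw Flux.Optimise.SkipException
      Flux.Optimise.update!(opt, ps, gs)   # skipped if the prehook threw
      cb(l, ps, gs, d, opt)
    catch ex
      ex isa Flux.Optimise.SkipException ? continue : rethrow()
    end
  end
end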

@DhairyaLGandhi (Member, Author)

Simple usage example

struct Callback{T}
  losses::AbstractVector{T}
end

# l: loss at the datapoint
# ps: params (maybe can skip but good to have to avoid globals)
# gs: grads at the datapoint to inspect
# d: datapoint
# opt: modify optimiser based on some condition

(cb::Callback)(l, ps, gs, d, opt) = push!(cb.losses, l)

prehook(l, ps, gs, d, opt) = throw(Flux.Optimise.SkipException())

c = Callback(Float32[])
Flux.train!(loss, ps, data, opt, cb = [() -> (), c])
Flux.train!(loss, ps, data, opt, prehooks = prehook)

@darsnack (Member) left a comment

If adding these hooks is the desired behavior, then I think the approach is fine.

But as I said in #1461, I don't think we should try to make a simple hook system for train!. If a user wants non-standard loop behavior (e.g. conditionals before the update), it is easier and cleaner to define your own for loop in your own train! function.
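For concreteness, a hand-rolled loop of the kind being suggested, where the "conditional before the update" is just ordinary code; my_train! and the NaN check are made up for illustration:

using Flux, Zygote

function my_train!(loss, ps, data, opt)
  losses = Float32[]
  for (x, y) in data
    l, back = Zygote.pullback(() -> loss(x, y), ps)
    gs = back(one(l))
    push!(losses, l)
    # skip the update if any gradient contains a NaN
    any(p -> gs[p] !== nothing && any(isnan, gs[p]), ps) && continue
    Flux.Optimise.update!(opt, ps, gs)
  end
  return losses
end

Returning the losses keeps them available to the caller without globals.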

update!(opt, ps, gs)
cb()
gs = back(l)
all(train_prehooks(l, ps, gs, d, opt)) && update!(opt, ps, gs)
Member:

"prehooks" doesn't seem like a good name for what the intent of this hook is. It isn't immediately obvious that the hook can block a parameter update.

Member Author:

I would be happy to hear thoughts on a better name

Member:

Maybe "gradient check" or "update check"?

cb()
gs = back(l)
all(train_prehooks(l, ps, gs, d, opt)) && update!(opt, ps, gs)
cb(l, ps, gs, d, opt)
Member:

This seems like a nasty list of args...we're just passing everything into the hook regardless of whether it is necessary. I feel like this is further evidence that a simple hook system is not a scalable design.

Member:

Fair enough, but the idea was really to mirror the existing train! API; that was the intent here. We definitely don't have to keep that signature, though.

Member:

If we do want to add this PR to Flux, then maybe the argument list can be just l, d, gs (maybe we don't even need d in there), because ps and opt will be accessible in the scope calling train!, so the hook function can just close over them.
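A sketch of that reduced signature, with ps and opt captured from the calling scope; the prehooks keyword is still the hypothetical API from this PR:

using Flux

model = Dense(10, 1)
ps    = Flux.params(model)
opt   = Descent(0.1)

# only needs the loop's locals; skips the update on a non-finite loss
check_update(l, d, gs) = isfinite(l) || throw(Flux.Optimise.SkipException())

# needs the optimiser, but closes over it instead of taking it as an argument
halve_lr(l, d, gs) = l > 1f3 && (opt.eta /= 2)

# hypothetical call with the reduced signature:
# Flux.train!(loss, ps, data, opt, prehooks = [check_update, halve_lr])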

Member Author:

We can expose a train context within which the step! takes place and the callback has access to the local scope too.
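One purely hypothetical shape for that idea, with nothing here being existing Flux API: a step function that runs the user's do-block with the step's locals in scope:

using Flux, Zygote

function trainstep!(f, loss, ps, data, opt)
  for d in data
    l, back = Zygote.pullback(() -> loss(d...), ps)
    gs = back(one(l))
    f(l, gs, d)                        # user code sees this step's locals
    Flux.Optimise.update!(opt, ps, gs)
  end
end

# usage with do-syntax:
# trainstep!(loss, ps, data, opt) do l, gs, d
#   println("loss = ", l)
# end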

Member Author:

What does sorry?

Member:

Ignore that comment. I posted before the page refreshed with your latest comment.

@DhairyaLGandhi (Member, Author)

This isn't saying that we should force more people into the train! function; it's saying that we can make train! more usable by itself. I am for the proper docs, fwiw.

@DhairyaLGandhi (Member, Author)

I think the bit I most want is exposing the loss via pullback and making it available to callbacks without extra compute.

@darsnack (Member) commented Jan 20, 2021

This isn't saying that we should force more people into the train! function; it's saying that we can make train! more usable by itself. I am for the proper docs, fwiw.

I think this PR definitely makes train! more useful. But from an API design perspective, I wonder if it is the right choice to make train! do more than it already does. If you look at a library like fast.ai, then there are many places to insert code into the training loop. My worry is that adding hooks like this will take us down a rabbit-hole. And my fear with the rabbit-hole is that it will give us more to maintain.

I think the bit I most want is exposing the loss via pullback and making it available to callbacks without extra compute.

I do agree that this is a major pain point of train!, and there isn't a single ML program that doesn't log the loss periodically. But I think there are other, simpler approaches to this. The main issue is that some things introduced into train!'s scope aren't available outside it. We could simply keep what this PR does with pullback and change the callback functions to accept two inputs: the current loss and the current gradients. Alternatively, we could have a logger keyword that calls a one-argument function that accepts the loss.
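A sketch of the logger-keyword alternative; logger is a hypothetical one-argument keyword, not an existing train! option:

using Flux, Zygote

function train_logged!(loss, ps, data, opt; logger = l -> nothing)
  for d in data
    l, back = Zygote.pullback(() -> loss(d...), ps)
    gs = back(one(l))
    logger(l)                          # the loss is reused, not recomputed
    Flux.Optimise.update!(opt, ps, gs)
  end
end

# usage: train_logged!(loss, ps, data, opt, logger = l -> println("loss = ", l))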

@DhairyaLGandhi (Member, Author)

The thing is, from an API design perspective, our loop is pretty well engineered. There are few places to hook code into, and most checkpointing can be done in the loss function while teaching the AD to ignore it, which is trivial by wrapping things up in a function and no-gradding it.

I have also proposed adding a scope in the training loop which users can work with. This would be preferable I think, and is in line with #1017
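As an aside on the checkpoint-in-the-loss point above, a sketch of that pattern using Zygote.ignore, so the side effect contributes nothing to the gradient; the condition and file name are made up:

using Flux, Zygote, BSON

model = Dense(10, 1)

function loss(x, y)
  l = Flux.Losses.mse(model(x), y)
  Zygote.ignore() do
    if l < 0.01f0
      BSON.@save "checkpoint.bson" model   # no gradient taken through this
    end
  end
  return l
end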

@DhairyaLGandhi (Member, Author)

I don't think it's a rabbit hole because there aren't many places to add hooks to. You can already have multiple prehooks and multiple callbacks, and what you do with them is up to you; the library exposes the necessary functionality, so to speak. Doing things before the gradients are taken? I'd be interested to see an example of that come up in practice.

@DhairyaLGandhi (Member, Author)

Note to self: check out some of the things fastai does and see what we can generalise as expected places for hooking up code.

@staticfloat (Contributor)

Hmmmm, this is a tough one. It strikes at the philosophy of what Flux truly is; is Flux a toolkit that you use to get things done, or is it an all-in-one application that you customize in order to perform one of the tasks it knows how to perform?

If it's the former, I support the "completely remove Flux.train!()" argument. It makes more sense for the docs to simply say "to train your model, you must now iterate over your dataset, calculate the loss at each step, and update your parameters, like this:". Yeah, this would be a hugely breaking change, but it also would reinforce the idea that Flux itself isn't the driver of the learning; it is merely giving you a bundle of tools to build architectures, push values through them, push losses back, and calculate gradients.

If it's the latter, then I think we should take a careful look around at what other people do and do our best to provide an ergonomic mapping of those hooks into the training loop. This will always be a moving target, because as new architectures and training regimens appear, this training loop will naturally shift and change. That's not necessarily a bad thing, but it will create a natural rift between the users that "paint within the box" and those that are forced to escape the confines presented by Flux.train!().

Anecdotally, I have never finished a project with Flux.train!(); I've always eventually had to escape out into something more complicated, but that is likely an artifact of the fact that I am often doing 'researchy' things with Flux. I'm not really living in the nicely-manicured playground of supervised learning on datasets that fit into memory and have standard losses applied to them.

@mcabbott (Member)

Even if removing train! is too big a step, it could stick around only for the "Flux in one tweet!" use case. Although even there it's not great: it takes 4 positional arguments and doesn't throw very friendly error messages if you get them in the wrong order.

The nice thing about being in Julia is that you don't have to learn some limited sub-language to interact with some incomprehensible back-end monster via its own callback API. You already know how to write a loop. If you write epoch % 10 == 0 && println(...) in it, you will never be confused about what is going to happen. Anyone remember exactly how Flux.stop is supposed to be used in a callback, without looking it up? Or you could write loss < 0.1 && break.
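For comparison, the plain-loop version of that, with ordinary control flow instead of callbacks and Flux.stop; all of the setup here is made up for illustration:

using Flux, Zygote

model = Dense(1, 1)
ps    = Flux.params(model)
opt   = Descent(0.1)
x, y  = rand(Float32, 1, 16), rand(Float32, 1, 16)
loss() = Flux.Losses.mse(model(x), y)

for epoch in 1:100
  l, back = Zygote.pullback(loss, ps)
  Flux.Optimise.update!(opt, ps, back(one(l)))
  epoch % 10 == 0 && println("epoch $epoch: loss = $l")   # periodic logging
  l < 0.1 && break                                        # early stopping, no Flux.stop needed
end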

@darsnack (Member)

Pinging @ChrisRackauckas here to avoid the model-zoo issue getting side-tracked.

What I exactly want is for people to copy the for loop everywhere; they don't need the complication of train!.

If everyone has to copy the same piece of code around then the abstraction is wrong and we should change it. Maybe callbacks need to support more things, or there can be a few more switches. But telling everyone to roll it out by hand is only useful to exactly the same devs building the library.

The for-loop will always be a first class API, so you can't go wrong by using custom loops everywhere.

Indeed, but flexibility will always be limiting. There will always be some optimizers that require less flexibility (BFGS, KrylovTrustRegion), and so fully promoting the most open choice in a function sense is also limiting in another sense. The path forward is to try and tame what can be done and capture what users do into a simplified API to then specialize and help the code perform better, not let it run too loose. Isn't that the point of Flux in the first place since you can just define the layers by hand?


But telling everyone to roll it out by hand is only useful to exactly the same devs building the library.

Agreed, but the argument about APIs is not that the for-loop should be the highest-level abstraction that exists anywhere; it's that it should be the highest level in Flux.jl. You don't need to look further than the numerous uses of Flux.jl in Julia, or the myriad training-loop abstractions available in other frameworks, to see that designing a good simplified API is hard to do.

It's undeniable that more keyword options in train! will make it more accessible. But this "one function with a ton of arguments" approach seems to me like trying to jam a square, star, and triangle peg into a round hole. I completely agree with you that we should be trying to capture use-cases into a simplified API. I'd just recommend that we don't do it in Flux.jl with train!.

@DhairyaLGandhi (Member, Author)

one function with a ton of arguments approach

Again, to reiterate, that was meant to mirror an existing API which had already been around for a while. There is no reason to stick to it, and an alternative has been offered already, so hopefully that takes care of that.

It's that it should be the highest-level in Flux.jl.

Not really? If I didn't tell you that we use a for loop, but can't quite do training by recursion or yield-based iteration, it would be irresponsible of me as a package author. The limitations of the package, its semantics, and its supported use cases are part of this high-level abstraction as well.

For folks who have been engaged, I don't doubt for a second that they'd jump into writing hugely complex routines, but I wouldn't want someone writing their first 200-300 lines of Julia to have to go through that.

Many savvy people on Slack or Zulip would still need a hot minute to process the do syntax.

@darsnack (Member)

I feel like we're arguing in circles a bit. I agree with you that wrapping gradient/update! into a single step! will be a nice change. And I also agree with the design strategies in this PR for how to bring more complex functionality to train!.

But as mentioned, the main discussion here is philosophical, and my reservations are not based on smaller technical details. I think it's clear where I stand on the philosophical question 😄. I tried to articulate why in the original issue, and I think @mcabbott's comments here are a good summary of that.
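For reference, a sketch of the step! wrapper mentioned above: a single call that does pullback + update! and returns the loss; step! is not an existing Flux function:

using Flux, Zygote

function step!(loss, ps, d, opt)
  l, back = Zygote.pullback(() -> loss(d...), ps)
  Flux.Optimise.update!(opt, ps, back(one(l)))
  return l
end

# a training loop then reduces to:
# for d in data
#   l = step!(loss, ps, d, opt)
#   # log, check, or schedule here with l in scope
# end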

@ToucheSir (Member)

I think this would be a great discussion topic for an ML call :)
