From bd74557c2da3602120cfd3d44a351d0e441df0b4 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Mon, 18 Jan 2021 10:59:13 +0100 Subject: [PATCH] Fix GarbageCollect callback (#57) * Fix GarbageCollect callback * add regression test case for GarbageCollect * fix for platforms other than linux * remove ")" * fix GarbageCollect construction * append --- src/callbacks/callbacks.jl | 19 ++++++++++++++++--- test/callbacks/garbagecollect.jl | 8 ++++++++ test/runtests.jl | 1 + 3 files changed, 25 insertions(+), 3 deletions(-) create mode 100644 test/callbacks/garbagecollect.jl diff --git a/src/callbacks/callbacks.jl b/src/callbacks/callbacks.jl index e9a5fa373..c86b5bb15 100644 --- a/src/callbacks/callbacks.jl +++ b/src/callbacks/callbacks.jl @@ -122,14 +122,27 @@ function on(::BatchBegin, ::Phase, cb::ToGPU, learner) end -garbagecollect() = (GC.gc(); ccall(:malloc_trim, Cvoid, (Cint,), 0)) +function garbagecollect() + GC.gc() + if Base.Sys.islinux() + ccall(:malloc_trim, Cvoid, (Cint,), 0) + end +end + """ GarbageCollect(nsteps) Every `nsteps` steps, forces garbage collection. -Use this if you get memory leaks from, for example, parallel data loading. +Use this if you get memory leaks from, for example, +parallel data loading. + +Performs an additional C-call on Linux systems that can +sometimes help. """ function GarbageCollect(nsteps::Int = 100) - return throttle(CustomCallback((learner) -> garbagecollect(), BatchEnd), freq = nsteps) + return throttle( + CustomCallback((learner) -> garbagecollect(), BatchEnd, Phase), + BatchEnd(), + freq = nsteps) end diff --git a/test/callbacks/garbagecollect.jl b/test/callbacks/garbagecollect.jl new file mode 100644 index 000000000..4e53c90e5 --- /dev/null +++ b/test/callbacks/garbagecollect.jl @@ -0,0 +1,8 @@ + +include("../imports.jl") + +@testset ExtendedTestSet "`GarbageCollect`" begin + cb = GarbageCollect() + learner = testlearner(Recorder(), Metrics(), cb) + @test_nowarn fit!(learner, 1) +end diff --git a/test/runtests.jl b/test/runtests.jl index 7b2ba9fc6..844438bfa 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,5 +11,6 @@ include("./imports.jl") include("./callbacks/recorder.jl") include("./callbacks/scheduler.jl") include("./callbacks/checkpointer.jl") + include("./callbacks/garbagecollect.jl") include("./callbacks/sanitycheck.jl") end