Skip to content

Commit

Permalink
Kill process group instead of iterator of pids in shutdown hook (inte…
Browse files Browse the repository at this point in the history
…l-analytics#4494)

* kill process group instead of process iter

* change name

* change name

* update doc

* fix style

* change to string
  • Loading branch information
shanyu-sys committed Aug 18, 2021
1 parent 4ac8f87 commit 36fe05d
Showing 1 changed file with 19 additions and 29 deletions.
48 changes: 19 additions & 29 deletions net/python/PythonZooNet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -125,43 +125,33 @@ class PythonZooNet[T: ClassTag](implicit ev: TensorNumeric[T]) extends PythonZoo
TFNet(path, config)
}

val processToBeKill = new CopyOnWriteArrayList[String]()
var processGpToBeKill: String = ""
registerKiller()

private def killPids(killingList: JList[String], killCommand: String): Unit = {
try {
val iter = killingList.iterator()
while(iter.hasNext) {
val pid = iter.next()
println("JVM is stopping process: " + pid)
val process = Runtime.getRuntime().exec(killCommand + pid)
process.waitFor(2, TimeUnit.SECONDS)
if (process.exitValue() == 0) {
iter.remove()
}
}
} catch {
case e : Exception =>
}
private def killPgid(pgid: String, killCommand: String): Boolean = {
println("JVM is stopping process group: " + pgid)
val process = Runtime.getRuntime().exec(killCommand + pgid)
process.waitFor(2, TimeUnit.SECONDS)
process.exitValue() == 0
}

private def registerKiller(): Unit = {
Logger.getLogger("py4j.reflection.ReflectionEngine").setLevel(Level.ERROR)
Logger.getLogger("py4j.GatewayConnection").setLevel(Level.ERROR)
Runtime.getRuntime().addShutdownHook(new Thread {
override def run(): Unit = {
// Give it a chance to be gracefully killed
killPids(processToBeKill, "kill ")
if (!processToBeKill.isEmpty) {
Thread.sleep(2000)
killPids(processToBeKill, "kill -9")
}
}
})
}

def jvmGuardRegisterPids(pids: ArrayList[Integer]): Unit = {
pids.asScala.foreach(pid => processToBeKill.add(pid + ""))
override def run(): Unit = {
if (processGpToBeKill == "") return
// Give it a chance to be gracefully killed
val success = killPgid(processGpToBeKill, "kill -- -")
if (!success) {
killPgid(processGpToBeKill, "kill -9 -")
}
}
})
}

def jvmGuardRegisterPgid(gpid: Int): Unit = {
this.processGpToBeKill = gpid.toString
}

def getModuleExtraParameters(model: AbstractModule[_, _, T]): Array[JTensor] = {
Expand Down

0 comments on commit 36fe05d

Please sign in to comment.