Skip to content

Commit

Permalink
Merge pull request xaynetwork#43 from xainag/memory-leak
Browse files Browse the repository at this point in the history
Fix Python memory leak
  • Loading branch information
little-dude committed Mar 6, 2020
2 parents 3c78329 + bf6da21 commit 0e8f26d
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 29 deletions.
1 change: 1 addition & 0 deletions python/aggregators/xain_aggregators/weighted_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

class Aggregator(AggregatorABC):
def __init__(self):
logging.basicConfig(level=logging.INFO)
LOG.info("initializing aggregator")
self.global_weights: np.ndarray = None
self.weights: List[np.ndarray] = []
Expand Down
12 changes: 7 additions & 5 deletions python/client_examples/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,13 @@ def train_round(self, training_input: TrainingInput) -> TrainingResult:
weights=training_input.weights, shapes=self.model_shapes
)

# FIXME: the epoch should come from the aggregator but I don't
# understand what it is exactly. According to Jan it's only
# used for metrics so I think it's ok to hardcode this to 10.
for _ in range(0, 10):
self.model.fit(x=self.trainset, verbose=2, shuffle=False)
# Uncomment this if you want to train the model. However, this
# is a dummy example with random data so there's no point in
# actually training.
#
# epochs = 10
# for _ in range(0, epochs):
# self.model.fit(x=self.trainset, verbose=2, shuffle=False)

# return the updated model weights and the number of training samples
return TrainingResult(self.get_tensorflow_weights(), self.train_samples)
Expand Down
2 changes: 1 addition & 1 deletion python/sdk/xain_sdk/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import requests

LOG = logging.getLogger("participant")
LOG = logging.getLogger("xain-sdk.http")


def log_headers(headers):
Expand Down
11 changes: 11 additions & 0 deletions python/sdk/xain_sdk/participant.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,32 +114,43 @@ def __init__(self, coordinator_url: str, participant: ParticipantABC):
def run(self):
self.rendez_vous()
while True:
LOG.info("waiting for being selected")
with self.state_record:
self.state_record.wait_until_selected_or_done()
state, _ = self.state_record.lookup()

if state == State.DONE:
LOG.info("state changed: DONE")
return

if state == State.TRAINING:
LOG.info("state changed: TRAINING")
try:
LOG.info("requesting training information to the coordinator")
self.aggregator_client = self.coordinator_client.start_training()
except StartTrainingRejected:
LOG.warning("start training request rejected")
with self.state_record:
self.state_record.set_state(State.WAITING)

LOG.info("downloading global weights from the aggregator")
data = self.aggregator_client.download()
LOG.info("retrieved training data (length: %d bytes)", len(data))
training_input = self.participant.deserialize_training_input(data)

if training_input.is_initialization_round():
LOG.info("initializing the weights")
result = self.participant.init_weights()
else:
LOG.info("training")
result = self.participant.train_round(training_input)
assert isinstance(result, TrainingResultABC)
LOG.info("training finished")

LOG.info("sending the local weights to the aggregator")
self.aggregator_client.upload(result.tobytes())

LOG.info("going back to WAITING state")
with self.state_record:
self.state_record.set_state(State.WAITING)

Expand Down
4 changes: 3 additions & 1 deletion python/sdk/xain_sdk/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging


def configure_logging(level=logging.DEBUG):
def configure_logging(level=logging.INFO):
logging.basicConfig(level=level, format="%(asctime)s %(levelname)-8s %(message)s")
http_logger = logging.getLogger("xain-sdk.http")
http_logger.setLevel(logging.WARNING)
logging.getLogger("requests").setLevel(logging.WARNING)
logging.getLogger("urllib3").setLevel(logging.WARNING)

Expand Down
61 changes: 40 additions & 21 deletions rust/src/aggregator/py_aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,61 +14,82 @@ use tokio::{
use crate::aggregator::{service::Aggregator, settings::PythonAggregatorSettings};
use pyo3::{
types::{PyBytes, PyModule},
PyObject, PyResult, Python, ToPyObject,
GILGuard, PyObject, PyResult, Python, ToPyObject,
};

pub struct PyAggregator<'py> {
py: Python<'py>,
pub struct PyAggregator {
gil: Option<GILGuard>,
aggregator: PyObject,
}

impl<'py> PyAggregator<'py> {
pub fn load(py: Python<'py>, settings: PythonAggregatorSettings) -> PyResult<Self> {
impl PyAggregator {
pub fn load(settings: PythonAggregatorSettings) -> PyResult<Self> {
let gil = Python::acquire_gil();
let py = gil.python();
// FIXME: make this configurable
let module = PyModule::import(py, &settings.module)
.map_err(|e| e.print(py))
.unwrap();
let aggregator = module.call0(&settings.class).unwrap().to_object(py);
Ok(Self { py, aggregator })
Ok(Self {
gil: Some(gil),
aggregator,
})
}

pub fn aggregate(&self) -> PyResult<Bytes> {
pub fn aggregate(&mut self) -> PyResult<Bytes> {
info!("PyAggregator: running aggregation");
let py = self.get_py();
let result = self
.aggregator
.call_method0(self.py, "aggregate")?
.extract::<Vec<u8>>(self.py)
.call_method0(py, "aggregate")?
.extract::<Vec<u8>>(py)
.map(Bytes::from)?;
info!("PyAggregator: finished aggregation");
self.re_acquire_gil();
Ok(result)
}

/// Release the GIL so that python's garbage collector runs
fn re_acquire_gil(&mut self) {
self.gil = None;
self.gil = Some(Python::acquire_gil());
}

pub fn get_global_weights(&self) -> PyResult<Bytes> {
let py = self.get_py();
Ok(self
.aggregator
.call_method0(self.py, "get_global_weights")?
.extract::<Vec<u8>>(self.py)
.call_method0(py, "get_global_weights")?
.extract::<Vec<u8>>(py)
.map(Bytes::from)?)
}

pub fn add_weights(&self, local_weights: &[u8]) -> PyResult<Result<(), ()>> {
info!("PyAggregator: adding weights");
let py_bytes = PyBytes::new(self.py, local_weights);
let py = self.get_py();
let py_bytes = PyBytes::new(py, local_weights);
let args = (py_bytes,);
let result = self
.aggregator
.call_method1(self.py, "add_weights", args)?
.extract::<bool>(self.py)?
.call_method1(py, "add_weights", args)?
.extract::<bool>(py)?
.then_some(())
.ok_or(());
info!("PyAggregator: done adding weights");
Ok(result)
}

pub fn reset(&self, global_weights: &[u8]) -> PyResult<()> {
let py_bytes = PyBytes::new(self.py, global_weights);
pub fn get_py(&self) -> Python<'_> {
self.gil.as_ref().unwrap().python()
}

pub fn reset(&mut self, global_weights: &[u8]) -> PyResult<()> {
let py = self.get_py();
let py_bytes = PyBytes::new(py, global_weights);
let args = (py_bytes,);
self.aggregator.call_method1(self.py, "reset", args)?;
self.aggregator.call_method1(py, "reset", args)?;
self.re_acquire_gil();
Ok(())
}
}
Expand Down Expand Up @@ -123,14 +144,12 @@ async fn py_aggregator(
mut aggregate_requests: RequestRx<(), Weights>,
mut add_weights_requests: RequestRx<Weights, ()>,
) {
let gil = Python::acquire_gil();
let py = gil.python();
let aggregator = PyAggregator::load(py, settings).unwrap();
let mut aggregator = PyAggregator::load(settings).unwrap();

loop {
select! {
Some(((), resp_tx)) = aggregate_requests.recv() => {
let weights = aggregator.aggregate().unwrap();
let weights = aggregator.aggregate().map_err(|e| error!("{:?}", e)).unwrap();
if resp_tx.send(weights).is_err() {
warn!("cannot send aggregate response, receiver has been dropped");
return;
Expand Down
2 changes: 1 addition & 1 deletion rust/src/bin/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use xain_fl::aggregator::{
api,
py_aggregator::spawn_py_aggregator,
service::AggregatorService,
settings::{AggregationSettings, PythonAggregatorSettings, Settings},
settings::{AggregationSettings, Settings},
};

#[tokio::main]
Expand Down

0 comments on commit 0e8f26d

Please sign in to comment.