diff --git a/include/xwidgets/xbinary.hpp b/include/xwidgets/xbinary.hpp index 8325ac9..5f2d64e 100644 --- a/include/xwidgets/xbinary.hpp +++ b/include/xwidgets/xbinary.hpp @@ -23,11 +23,10 @@ namespace xw { using xjson_path_type = std::vector; - XWIDGETS_API void extract_buffer_paths( - const std::vector& to_check, + XWIDGETS_API void reorder_buffer_paths( + const std::vector& buffer_paths, const nl::json& patch, - const xeus::buffer_sequence& buffers, - nl::json& buffer_paths + std::vector& out ); XWIDGETS_API void insert_buffer_paths(nl::json& patch, const nl::json& buffer_paths); diff --git a/include/xwidgets/xcommon.hpp b/include/xwidgets/xcommon.hpp index eaa6fa6..5b89ebd 100644 --- a/include/xwidgets/xcommon.hpp +++ b/include/xwidgets/xcommon.hpp @@ -57,6 +57,8 @@ namespace xw xeus::xguid id() const noexcept; void display() const; + std::vector& buffer_paths(); + const std::vector& buffer_paths() const; protected: @@ -74,8 +76,6 @@ namespace xw const xeus::xcomm& comm() const; const xeus::xmessage*& hold(); const xeus::xmessage* const& hold() const; - std::vector& buffer_paths(); - const std::vector& buffer_paths() const; void open(nl::json&& patch, xeus::buffer_sequence&& buffers); void close(); diff --git a/include/xwidgets/xcontroller.hpp b/include/xwidgets/xcontroller.hpp index 36ba083..f0ef9d8 100644 --- a/include/xwidgets/xcontroller.hpp +++ b/include/xwidgets/xcontroller.hpp @@ -236,7 +236,12 @@ namespace xw inline xcontroller::xcontroller() : base_type() { - register_control_types(); + // Making a dummy static variable to only call the registration once. + static const auto initialized = []() + { + register_control_types(); + return true; + }(); set_defaults(); } diff --git a/include/xwidgets/xholder.hpp b/include/xwidgets/xholder.hpp index 68a6a40..7c996b6 100644 --- a/include/xwidgets/xholder.hpp +++ b/include/xwidgets/xholder.hpp @@ -14,10 +14,13 @@ #include #include -#include "nlohmann/json.hpp" -#include "xeus/xguid.hpp" -#include "xtl/xany.hpp" -#include "xtl/xclosure.hpp" +#include +#include +#include +#include +#include + +#include "xbinary.hpp" #include "xwidgets_config.hpp" namespace nl = nlohmann; @@ -68,6 +71,8 @@ namespace xw void display() const; xeus::xguid id() const; + void serialize_state(nl::json& state, xeus::buffer_sequence& buffers) const; + const std::vector& buffer_paths() const; xtl::any value() &; const xtl::any value() const&; @@ -131,6 +136,8 @@ namespace xw virtual void display() const = 0; virtual xeus::xguid id() const = 0; + virtual void serialize_state(nl::json& state, xeus::buffer_sequence& buffers) const = 0; + virtual const std::vector& buffer_paths() const = 0; virtual xtl::any value() & = 0; virtual const xtl::any value() const& = 0; @@ -162,31 +169,39 @@ namespace xw { } - virtual ~xholder_owning() - { - } + ~xholder_owning() override = default; - virtual base_type* clone() const override + base_type* clone() const override { return new xholder_owning(*this); } - virtual void display() const override + void display() const override { m_value.display(); } - virtual xeus::xguid id() const override + xeus::xguid id() const override { return m_value.id(); } - virtual xtl::any value() & override + void serialize_state(nl::json& state, xeus::buffer_sequence& buffers) const override + { + return m_value.serialize_state(state, buffers); + } + + const std::vector& buffer_paths() const override + { + return m_value.buffer_paths(); + } + + xtl::any value() & override { return xtl::closure(m_value); } - virtual const xtl::any value() const& override + const xtl::any value() const& override { return xtl::closure(m_value); } @@ -214,32 +229,42 @@ namespace xw { } - virtual ~xholder_weak() + ~xholder_weak() override { p_value = nullptr; } - virtual base_type* clone() const override + base_type* clone() const override { return new xholder_weak(*this); } - virtual void display() const override + void display() const override { p_value->display(); } - virtual xeus::xguid id() const override + xeus::xguid id() const override { return p_value->id(); } - virtual xtl::any value() & override + void serialize_state(nl::json& state, xeus::buffer_sequence& buffers) const override + { + return p_value->serialize_state(state, buffers); + } + + const std::vector& buffer_paths() const override + { + return p_value->buffer_paths(); + } + + xtl::any value() & override { return xtl::closure(*p_value); } - virtual const xtl::any value() const& override + const xtl::any value() const& override { return xtl::closure(*p_value); } @@ -274,29 +299,39 @@ namespace xw { } - virtual ~xholder_shared() = default; + ~xholder_shared() override = default; - virtual base_type* clone() const override + base_type* clone() const override { return new xholder_shared(*this); } - virtual void display() const override + void display() const override { p_value->display(); } - virtual xeus::xguid id() const override + xeus::xguid id() const override { return p_value->id(); } - virtual xtl::any value() & override + void serialize_state(nl::json& state, xeus::buffer_sequence& buffers) const override + { + return p_value->serialize_state(state, buffers); + } + + const std::vector& buffer_paths() const override + { + return p_value->buffer_paths(); + } + + xtl::any value() & override { return xtl::closure(*p_value); } - virtual const xtl::any value() const& override + const xtl::any value() const& override { return xtl::closure(*p_value); } diff --git a/include/xwidgets/xregistry.hpp b/include/xwidgets/xregistry.hpp index 5c22ac1..2bb1e79 100644 --- a/include/xwidgets/xregistry.hpp +++ b/include/xwidgets/xregistry.hpp @@ -30,6 +30,13 @@ namespace xw using holder_type = xholder; using storage_type = std::unordered_map; + using mapped_type = typename storage_type::mapped_type; + using const_iterator = typename storage_type::const_iterator; + + XWIDGETS_API const_iterator begin() const; + XWIDGETS_API const_iterator cbegin() const; + XWIDGETS_API const_iterator end() const; + XWIDGETS_API const_iterator cend() const; template void register_weak(xtransport* ptr); @@ -39,7 +46,7 @@ namespace xw XWIDGETS_API void unregister(xeus::xguid id); - XWIDGETS_API typename storage_type::mapped_type& find(xeus::xguid id); + XWIDGETS_API mapped_type& find(xeus::xguid id); private: diff --git a/src/xbinary.cpp b/src/xbinary.cpp index 377dcf0..0362759 100644 --- a/src/xbinary.cpp +++ b/src/xbinary.cpp @@ -99,23 +99,34 @@ namespace xw } } - void extract_buffer_paths( - const std::vector& to_check, + void reorder_buffer_paths( + const std::vector& buffer_paths, const nl::json& patch, - const xeus::buffer_sequence& buffers, - nl::json& buffer_paths + std::vector& out ) { - buffer_paths = nl::json(buffers.size(), nullptr); - for (const auto& path : to_check) + auto ensure_out_size = [&out](std::size_t size) + { + if (out.size() < size) + { + out.resize(size, nullptr); + } + }; + + ensure_out_size(buffer_paths.size()); + for (const auto& path : buffer_paths) { const nl::json* item = detail::get_buffers(patch, path); if (item != nullptr && item->is_string()) { - const std::string leaf = item->get(); + const auto& leaf = item->get(); if (is_buffer_reference(leaf)) { - buffer_paths[buffer_index(leaf)] = path; + auto const idx = buffer_index(leaf); + // Idx may be greater than to_check.size() when the buffers are used with + // multiple states + ensure_out_size(idx + 1); + out[idx] = path; } } } diff --git a/src/xcommon.cpp b/src/xcommon.cpp index 18b158d..7cc5e1e 100644 --- a/src/xcommon.cpp +++ b/src/xcommon.cpp @@ -154,8 +154,8 @@ namespace xw void xcommon::send_patch(nl::json&& patch, xeus::buffer_sequence&& buffers, const char* method) const { // extract buffer paths - auto paths = nl::json::array(); - extract_buffer_paths(buffer_paths(), patch, buffers, paths); + std::vector paths{}; + reorder_buffer_paths(buffer_paths(), patch, paths); // metadata nl::json metadata; @@ -174,8 +174,8 @@ namespace xw void xcommon::open(nl::json&& patch, xeus::buffer_sequence&& buffers) { // extract buffer paths - auto paths = nl::json::array(); - extract_buffer_paths(buffer_paths(), patch, buffers, paths); + std::vector paths{}; + reorder_buffer_paths(buffer_paths(), patch, paths); // metadata nl::json metadata; diff --git a/src/xholder.cpp b/src/xholder.cpp index 8e345b7..db7335a 100644 --- a/src/xholder.cpp +++ b/src/xholder.cpp @@ -64,6 +64,18 @@ namespace xw return p_holder->id(); } + void xholder::serialize_state(nl::json& state, xeus::buffer_sequence& buffers) const + { + check_holder(); + return p_holder->serialize_state(state, buffers); + } + + const std::vector& xholder::buffer_paths() const + { + check_holder(); + return p_holder->buffer_paths(); + } + xtl::any xholder::value() & { check_holder(); diff --git a/src/xholder_id.cpp b/src/xholder_id.cpp index 1660806..83a4728 100644 --- a/src/xholder_id.cpp +++ b/src/xholder_id.cpp @@ -32,35 +32,41 @@ namespace xw { } - virtual ~xholder_id() = default; + ~xholder_id() override = default; - virtual base_type* clone() const override + base_type* clone() const override { return new xholder_id(*this); } - virtual void display() const override + void display() const override { - auto& holder = get_transport_registry().find(m_id); - holder.display(); + return get_transport_registry().find(m_id).display(); } - virtual xeus::xguid id() const override + xeus::xguid id() const override { - auto& holder = get_transport_registry().find(m_id); - return holder.id(); + return get_transport_registry().find(m_id).id(); } - virtual xtl::any value() & override + void serialize_state(nl::json& state, xeus::buffer_sequence& buffers) const override { - auto& holder = get_transport_registry().find(m_id); - return holder.value(); + return get_transport_registry().find(m_id).serialize_state(state, buffers); } - virtual const xtl::any value() const& override + const std::vector& buffer_paths() const override { - const auto& holder = get_transport_registry().find(m_id); - return holder.value(); + return get_transport_registry().find(m_id).buffer_paths(); + } + + xtl::any value() & override + { + return get_transport_registry().find(m_id).value(); + } + + const xtl::any value() const& override + { + return get_transport_registry().find(m_id).value(); } private: diff --git a/src/xregistry.cpp b/src/xregistry.cpp index ae31745..7b01067 100644 --- a/src/xregistry.cpp +++ b/src/xregistry.cpp @@ -1,3 +1,11 @@ +/*************************************************************************** + * Copyright (c) 2022, QuantStack and XWidgets contributors * + * * + * Distributed under the terms of the BSD 3-Clause License. * + * * + * The full license is in the file LICENSE, distributed with this software. * + ****************************************************************************/ + #include "xwidgets/xregistry.hpp" namespace xw @@ -17,6 +25,26 @@ namespace xw return it->second; } + auto xregistry::begin() const -> const_iterator + { + return cbegin(); + } + + auto xregistry::cbegin() const -> const_iterator + { + return m_storage.cbegin(); + } + + auto xregistry::end() const -> const_iterator + { + return cend(); + } + + auto xregistry::cend() const -> const_iterator + { + return m_storage.cend(); + } + xregistry& get_transport_registry() { static xregistry instance; diff --git a/src/xtarget.cpp b/src/xtarget.cpp index 6a57a28..75b5d9a 100644 --- a/src/xtarget.cpp +++ b/src/xtarget.cpp @@ -8,14 +8,21 @@ #include "xtarget.hpp" +#include +#include #include #include +#include #include +#include #include #include +#include "xwidgets/xbinary.hpp" +#include "xwidgets/xcommon.hpp" #include "xwidgets/xfactory.hpp" +#include "xwidgets/xregistry.hpp" #include "xwidgets/xwidgets_config.hpp" namespace xw @@ -27,6 +34,12 @@ namespace xw return "jupyter.widget"; } + /** + * Check frontend widget version and instanciate widget. + * + * This callback function is called by Xeus when a comm channel is open by the frontend + * to create a companion widget in the kernel. + */ void xobject_comm_opened(xeus::xcomm&& comm, const xeus::xmessage& msg) { const nl::json& content = msg.content(); @@ -58,7 +71,9 @@ namespace xw void register_widget_target() { xeus::get_interpreter().comm_manager().register_comm_target( + /** The target name */ get_widget_target_name(), + /** Callback for comm opened by the frontend on this target, one per widget */ xobject_comm_opened ); } @@ -66,7 +81,141 @@ namespace xw xeus::xtarget* get_widget_target() { - register_widget_target(); + // Making a dummy static variable to only call the registration once. + static const auto initialized = []() + { + register_widget_target(); + return true; + }(); return ::xeus::get_interpreter().comm_manager().target(get_widget_target_name()); } + + namespace + { + template + std::vector + prepend_to_json_paths(std::vector paths, JsonPath const& prefix) + { + std::for_each( + paths.begin(), + paths.end(), + [&](xjson_path_type& p) + { + p.insert(p.begin(), prefix.begin(), prefix.end()); + } + ); + return paths; + } + + void + serialize_all_states(nl::json& states, xeus::buffer_sequence& buffers, std::vector& buffer_paths) + { + for (auto const& id_and_widget : get_transport_registry()) + { + auto const& holder = id_and_widget.second; + // This is not what the protocol states (?) but what IPyWidgets does + // https://github.com/jupyter-widgets/ipywidgets/issues/3685 + nl::json stateish = nl::json::object(); + holder.serialize_state(stateish["state"], buffers); + stateish["model_name"] = stateish["state"]["_model_name"]; + stateish["model_module"] = stateish["state"]["_model_module"]; + stateish["model_module_version"] = stateish["state"]["_model_module_version"]; + states[holder.id()] = std::move(stateish); + // Add buffer paths, but add the xguid/state prefix of multi-state schema + reorder_buffer_paths( + prepend_to_json_paths(holder.buffer_paths(), std::array{holder.id(), "state"}), + states, + buffer_paths + ); + } + } + + /** + * Register the ``on_message`` callback on the comm to get all widgets states. + * + * This callback function is called by Xeus when a comm channel is open by the frontend + * on the ``jupyter.widget.control`` target. + * This happens when the frontend needs to get the state of all widgets and no immediate + * action is required. + * Following the opening of the comm, the frontend sends a message with a + * ``request_states`` method, to which the kernel replies with the state of all widgets. + * + * After the frontend recieves the ``update_states`` response it closes the comm. + * Additional (and simulataneous) comms can be opened for fetching states. + */ + void control_comm_opened(xeus::xcomm&& comm, const xeus::xmessage&) + { + // This is a very simple registry for comm since their lifetime is managed by the + // frontend + static std::unordered_map comm_registry{}; + + auto iter_inserted = comm_registry.emplace(std::make_pair(comm.id(), std::move(comm))); + // Should really be inserted, but in case it is not, we let the comm gets destroyed and closed + assert(iter_inserted.second); + if (!iter_inserted.second) + { + return; + } + + auto& registered_comm = iter_inserted.first->second; + + registered_comm.on_message( + [&](const ::xeus::xmessage& msg) + { + auto const& method = msg.content()["data"]["method"]; + + nl::json states = nl::json::object(); + xeus::buffer_sequence buffers{}; + std::vector buffer_paths{}; + serialize_all_states(states, buffers, buffer_paths); + + nl::json metadata = {{"version", XWIDGETS_PROTOCOL_VERSION}}; + + nl::json data = nl::json::object(); + data["method"] = "update_states"; + data["states"] = std::move(states); + data["buffer_paths"] = std::move(buffer_paths); + + registered_comm.send(std::move(metadata), std::move(data), std::move(buffers)); + } + ); + + registered_comm.on_close( + [&](const ::xeus::xmessage&) + { + // This is not trivial. The comm is destructed from within one of its method. + // This works because no other instruction are executed by Xeus afterwards. + comm_registry.erase(registered_comm.id()); + } + ); + } + + const char* get_control_target_name() + { + return "jupyter.widget.control"; + } + + /** + * Register the ``jupyter.widget.control`` Xeus target. + * + * This target is used by the frontend to get the state of all widget in a single message + * (_e.g._ when restarting). + */ + void register_control_target() + { + xeus::get_interpreter().comm_manager().register_comm_target( + /** The target name */ + get_control_target_name(), + /** Callback for comm opened by the frontend on this target */ + control_comm_opened + ); + } + + // Making a dummy static variable to call the registration at load time. + static const auto initialized = []() + { + register_control_target(); + return true; + }(); + } } diff --git a/src/xtarget.hpp b/src/xtarget.hpp index 2dd96c2..c3fa5aa 100644 --- a/src/xtarget.hpp +++ b/src/xtarget.hpp @@ -13,6 +13,13 @@ namespace xw { + /** + * Return the ``jupyter.widget`` Xeus target. + * + * This target is used by the comms of the widgets to synchronize state between the frontend + * (_e.g._ jupyterlab) and the backend (kernel). + * This function will register the target with Xeus upon first call. + */ xeus::xtarget* get_widget_target(); }