Skip to content

Commit

Permalink
feat: add MapSerializer
Browse files Browse the repository at this point in the history
  • Loading branch information
bsorrentino committed Aug 26, 2024
1 parent 40a910f commit 1407b41
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 69 deletions.
Original file line number Diff line number Diff line change
@@ -1,74 +1,22 @@
package org.bsc.langgraph4j.serializer;

import lombok.extern.log4j.Log4j;
import lombok.extern.slf4j.Slf4j;
import org.bsc.langgraph4j.state.AgentState;

import java.io.*;
import java.util.HashMap;
import java.util.Map;

@Slf4j
public class AgentStateSerializer implements Serializer<AgentState> {
public static final AgentStateSerializer INSTANCE = new AgentStateSerializer();
private AgentStateSerializer() {}

@Override
public void write(AgentState object, ObjectOutput out) throws IOException {
try( ByteArrayOutputStream baos = new ByteArrayOutputStream() ) {
int actuoalSize = 0;

final ObjectOutputStream tupleStream = new ObjectOutputStream( baos );
for( Map.Entry<String,Object> e : object.data().entrySet() ) {
try {
tupleStream.writeUTF(e.getKey());
tupleStream.writeObject(e.getValue());
++actuoalSize;
} catch (IOException ex) {
log.error( "Error writing state key '{}' - {}", e.getKey(), ex.getMessage() );
throw ex;
}
}

out.writeInt( object.data().size() );
out.writeInt( actuoalSize ); // actual size
byte[] data = baos.toByteArray();
out.writeInt( data.length );
out.write( data );

}

MapSerializer.INSTANCE.write( object.data(), out );
}

@Override
public AgentState read(ObjectInput in) throws IOException, ClassNotFoundException {
Map<String, Object> data = new HashMap<>();

int expectedSize = in.readInt();
int actualSize = in.readInt();
if( expectedSize > 0 && actualSize > 0 ) {

if( expectedSize != actualSize ) {
final String message = String.format( "Deserialize State: Expected size %d and actual size %d do not match!", expectedSize, actualSize ) ;
log.error( message ) ;
throw new IOException( message ) ;
}

int byteLen = in.readInt();
byte[] bytes = new byte[byteLen];
in.readFully(bytes);

try( ByteArrayInputStream bais = new ByteArrayInputStream( bytes ) ) {
ObjectInputStream ois = new ObjectInputStream( bais );

for( int i = 0; i < actualSize; i++ ) {
String key = ois.readUTF();
Object value = ois.readObject();
data.put(key, value);
}
}

}
Map<String, Object> data = MapSerializer.INSTANCE.read( in );
return new AgentState(data);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,18 @@ private CheckpointSerializer() {}

public void write( Checkpoint object, ObjectOutput out) throws IOException {
out.writeUTF(object.getId());
Checkpoint.Value value = object.getValue();
AgentStateSerializer.INSTANCE.write( value.getState(), out );
out.writeUTF( value.getNodeId() );
AgentStateSerializer.INSTANCE.write( object.getState(), out );
out.writeUTF( object.getNodeId() );
out.writeUTF( object.getNextNodeId() );
}

public Checkpoint read(ObjectInput in) throws IOException, ClassNotFoundException {
String id = in.readUTF();
AgentState state = AgentStateSerializer.INSTANCE.read( in );
String nodeId = in.readUTF();
Checkpoint.Value value = Checkpoint.Value.of( state, nodeId );
return new Checkpoint(id, value);
return Checkpoint.builder()
.id(in.readUTF())
.state(AgentStateSerializer.INSTANCE.read( in ))
.nodeId(in.readUTF())
.nextNodeId(in.readUTF())
.build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import java.util.Map;

@Slf4j
public class StateSerializer implements Serializer<Map<String,Object>> {
public static final StateSerializer INSTANCE = new StateSerializer();
private StateSerializer() {}
public class MapSerializer implements Serializer<Map<String,Object>> {
public static final MapSerializer INSTANCE = new MapSerializer();
private MapSerializer() {}

@Override
public void write(Map<String,Object> object, ObjectOutput out) throws IOException {
Expand All @@ -23,7 +23,7 @@ public void write(Map<String,Object> object, ObjectOutput out) throws IOExceptio
tupleStream.writeObject(e.getValue());
++actualSize;
} catch (IOException ex) {
log.error( "Error writing state key '{}' - {}", e.getKey(), ex.getMessage() );
log.error( "Error writing map key '{}' - {}", e.getKey(), ex.getMessage() );
throw ex;
}
}
Expand All @@ -47,7 +47,7 @@ public Map<String, Object> read(ObjectInput in) throws IOException, ClassNotFoun
if( expectedSize > 0 && actualSize > 0 ) {

if( expectedSize != actualSize ) {
final String message = String.format( "Deserialize State: Expected size %d and actual size %d do not match!", expectedSize, actualSize ) ;
final String message = String.format( "Deserialize map: Expected size %d and actual size %d do not match!", expectedSize, actualSize ) ;
log.error( message ) ;
throw new IOException( message ) ;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package org.bsc.langgraph4j.serializer;

import org.bsc.langgraph4j.checkpoint.Checkpoint;

import java.io.*;
import java.util.Objects;
import java.util.Optional;
Expand Down Expand Up @@ -51,4 +49,5 @@ default T cloneObject(T object) throws IOException, ClassNotFoundException {
Objects.requireNonNull( object, "object cannot be null" );
return readObject(writeObject(object));
}

}

0 comments on commit 1407b41

Please sign in to comment.