Commit
cli update
Updates CLI calls to exit more gracefully, and adds coverage tests for the CLI.
LMBooth committed Oct 25, 2023
1 parent 745218c commit f6e4241
Showing 6 changed files with 411 additions and 217 deletions.
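The graceful-exit mechanism this commit introduces centres on a module-level threading.Event that both a command-listener thread and an optional timeout thread can set, while the worker loop polls it instead of spinning in an unbreakable while(True). A minimal standalone sketch of that pattern (an illustration mirroring the diff below, not the pybci code itself):

import threading
import time

stop_signal = threading.Event()  # shared flag; any thread may set it

def command_listener():
    # Blocks on stdin; typing 'stop' sets the event so loops can exit.
    while not stop_signal.is_set():
        if input("Enter 'stop' to terminate\n") == 'stop':
            stop_signal.set()
            break

def stop_after_timeout(timeout):
    # Safety net: force shutdown after `timeout` seconds.
    time.sleep(timeout)
    stop_signal.set()

def worker():
    while not stop_signal.is_set():  # replaces an unbreakable while(True)
        time.sleep(0.2)  # placeholder for real work

threading.Thread(target=command_listener, daemon=True).start()
threading.Thread(target=stop_after_timeout, args=(5,), daemon=True).start()
worker()
print("Stopped cleanly.")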
75 changes: 58 additions & 17 deletions Tests/test_cli.py
@@ -1,19 +1,60 @@
#import time
#from pybci.CliTests.testSimple import main as mainSimple
#from pybci.CliTests.testSklearn import main as mainSklearn
#from pybci.CliTests.testPyTorch import main as mainPyTorch
#from pybci.CliTests.testTensorflow import main as mainTensorflow

import threading
from pybci.CliTests.testSimple import main as mainSimple
from pybci.CliTests.testSklearn import main as mainSklearn
from pybci.CliTests.testPyTorch import main as mainPyTorch
from pybci.CliTests.testTensorflow import main as mainTensorflow
from unittest.mock import patch

# Example usage
def test_cli():
#mainSimple(min_epochs_train=1, min_epochs_test=2, timeout=10)
#time.sleep(15)
#mainSklearn(min_epochs_train=1, min_epochs_test=2, timeout=10)
#time.sleep(15)
#mainPyTorch(min_epochs_train=1, min_epochs_test=2, timeout=10)
#time.sleep(15)
#mainTensorflow(min_epochs_train=1, min_epochs_test=2, timeout=10)
#time.sleep(15)
assert True

#def test_cli():
def test_cli_simple_timeout():
with patch('builtins.input', return_value='stop'):
timeout = 30 # timeout in seconds
my_bci_wrapper = None

def run_main():
nonlocal my_bci_wrapper
my_bci_wrapper = mainSimple(createPseudoDevice=True, min_epochs_train=1, min_epochs_test=2, timeout=timeout)

main_thread = threading.Thread(target=run_main)
main_thread.start()
main_thread.join()

def test_cli_sklearn_timeout():
with patch('builtins.input', return_value='stop'):
timeout = 30 # timeout in seconds
my_bci_wrapper = None

def run_main():
nonlocal my_bci_wrapper
my_bci_wrapper = mainSklearn(createPseudoDevice=True, min_epochs_train=1, min_epochs_test=2, timeout=timeout)

main_thread = threading.Thread(target=run_main)
main_thread.start()
main_thread.join()

def test_cli_pytorch_timeout():
with patch('builtins.input', return_value='stop'):
timeout = 30 # timeout in seconds
my_bci_wrapper = None

def run_main():
nonlocal my_bci_wrapper
my_bci_wrapper = mainPyTorch(createPseudoDevice=True, min_epochs_train=1, min_epochs_test=2, timeout=timeout)

main_thread = threading.Thread(target=run_main)
main_thread.start()
main_thread.join()

def test_cli_tensorflow_timeout():
with patch('builtins.input', return_value='stop'):
timeout = 30 # timeout in seconds
my_bci_wrapper = None

def run_main():
nonlocal my_bci_wrapper
my_bci_wrapper = mainTensorflow(createPseudoDevice=True, min_epochs_train=1, min_epochs_test=2, timeout=timeout)

main_thread = threading.Thread(target=run_main)
main_thread.start()
main_thread.join()
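All four tests follow the same recipe: unittest.mock.patch replaces builtins.input so the CLI's command listener reads 'stop' immediately instead of blocking CI on stdin, while main runs on a worker thread that the test joins. A minimal illustration of why the mock unblocks the listener (the listener function here is hypothetical, not from pybci):

from unittest.mock import patch

def listener():
    # In the real CLI this call blocks on stdin; under the mock it
    # returns 'stop' immediately.
    return input("PyBCI: [CLI] - Enter 'stop' to terminate\n")

with patch('builtins.input', return_value='stop'):
    assert listener() == 'stop'  # no interactive prompt needed in CI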
215 changes: 131 additions & 84 deletions pybci/CliTests/testPyTorch.py
@@ -6,104 +6,151 @@
from torch import nn
import threading

stop_signal = threading.Event() # Global event to control the main loop

def main(createPseudoDevice=True, min_epochs_train=4, min_epochs_test=10, num_chs = 8, num_feats = 2, num_classes = 4, timeout=None):
if createPseudoDevice:
num_chs = 8 # 8 channels are created in the PseudoLSLGenerator
num_feats = 2 # default is mean freq and rms to keep it simple
num_classes = 4 # number of different triggers (can include baseline) sent, defines if we use softmax or binary

class SimpleNN(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.bn1 = nn.BatchNorm1d(hidden_size)
self.relu = nn.ReLU(inplace=True) # In-place operation
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.bn2 = nn.BatchNorm1d(hidden_size)
self.fc3 = nn.Linear(hidden_size, num_classes)
global num_chs_g, num_feats_g, num_classes_g

class SimpleNN(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.bn1 = nn.BatchNorm1d(hidden_size)
self.relu = nn.ReLU(inplace=True) # In-place operation
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.bn2 = nn.BatchNorm1d(hidden_size)
self.fc3 = nn.Linear(hidden_size, num_classes)

def forward(self, x):
out = self.fc1(x)
if out.shape[0] > 1: # Skip BatchNorm if batch size is 1
out = self.bn1(out)
out = self.relu(out)
out = self.fc2(out)
if out.shape[0] > 1: # Skip BatchNorm if batch size is 1
out = self.bn2(out)
out = self.relu(out)
out = self.fc3(out)
return out

def PyTorchModel(x_train, x_test, y_train, y_test):
input_size = num_feats_g*num_chs_g # num of channels multiplied by number of default features (rms and mean freq)
hidden_size = 100
#num_classes = num_classes # default in pseudodevice
model = SimpleNN(input_size, hidden_size, num_classes_g)
model.train()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 10
train_data = TensorDataset(torch.Tensor(x_train), torch.Tensor(y_train).long())
train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True, drop_last=True) # Drop last incomplete batch
for epoch in range(epochs):
for inputs, labels in train_loader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
model.eval()
accuracy = 0
with torch.no_grad():
test_outputs = model(torch.Tensor(x_test))
_, predicted = torch.max(test_outputs.data, 1)
correct = (predicted == torch.Tensor(y_test).long()).sum().item()
accuracy = correct / len(y_test)
return accuracy, model


def command_listener():
while not stop_signal.is_set():
command = input("PyBCI: [CLI] - Enter 'stop' to terminate\n")
if command == 'stop':
stop_signal.set()
break


class CLI_testPytorchWrapper:
def __init__(self, createPseudoDevice, min_epochs_train, min_epochs_test, num_chs, num_feats, num_classes, timeout):
if createPseudoDevice:
self.num_chs = 8 # 8 channels are created in the PseudoLSLGenerator
self.num_feats = 2 # default is mean freq and rms to keep it simple
self.num_classes = 4 # number of different triggers (can include baseline) sent, defines if we use softmax or binary

self.createPseudoDevice = createPseudoDevice
self.timeout = timeout
self.min_epochs_train = min_epochs_train
self.min_epochs_test = min_epochs_test
self.accuracy = 0
self.currentMarkers = {}
if self.min_epochs_test <= self.min_epochs_train:
self.min_epochs_test = self.min_epochs_train+1

def forward(self, x):
out = self.fc1(x)
if out.shape[0] > 1: # Skip BatchNorm if batch size is 1
out = self.bn1(out)
out = self.relu(out)
out = self.fc2(out)
if out.shape[0] > 1: # Skip BatchNorm if batch size is 1
out = self.bn2(out)
out = self.relu(out)
out = self.fc3(out)
return out
def PyTorchModel(x_train, x_test, y_train, y_test):
input_size = num_feats*num_chs # num of channels multiplied by number of default features (rms and mean freq)
hidden_size = 100
#num_classes = num_classes # default in pseudodevice
model = SimpleNN(input_size, hidden_size, num_classes)
model.train()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 10
train_data = TensorDataset(torch.Tensor(x_train), torch.Tensor(y_train).long())
train_loader = DataLoader(dataset=train_data, batch_size=32, shuffle=True, drop_last=True) # Drop last incomplete batch
for epoch in range(epochs):
for inputs, labels in train_loader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
model.eval()
accuracy = 0
with torch.no_grad():
test_outputs = model(torch.Tensor(x_test))
_, predicted = torch.max(test_outputs.data, 1)
correct = (predicted == torch.Tensor(y_test).long()).sum().item()
accuracy = correct / len(y_test)
return accuracy, model

def loop(bci):
while not bci.connected: # check to see if lsl marker and datastream are available
bci.Connect()
self.bci = PyBCI(minimumEpochsRequired = min_epochs_train, createPseudoDevice=createPseudoDevice, torchModel = PyTorchModel)
#self.bci = PyBCI(minimumEpochsRequired = self.min_epochs_train, createPseudoDevice=self.createPseudoDevice)
main_thread = threading.Thread(target=self.loop)
main_thread.start()
if self.timeout:
print("PyBCI: [CLI] - starting timeout thread")
self.timeout_thread = threading.Thread(target=self.stop_after_timeout)
self.timeout_thread.start()
main_thread.join()
if timeout is not None:
self.timeout_thread.join()


def loop(self):
while not self.bci.connected: # check to see if lsl marker and datastream are available
self.bci.Connect()
time.sleep(1)
bci.TrainMode() # now both marker and data streams are available, start training on received epochs
accuracy = 0
self.bci.TrainMode() # now both marker and data streams are available, start training on received epochs
self.accuracy = 0
test = False
try:
while(True):
while not stop_signal.is_set(): # Add the check here
if test is False:
currentMarkers = bci.ReceivedMarkerCount() # check to see how many received epochs, if markers sent too close together will be ignored till done processing
self.currentMarkers = self.bci.ReceivedMarkerCount() # check to see how many received epochs, if markers sent too close together will be ignored till done processing
time.sleep(0.5) # wait for marker updates
print("Markers received: " + str(currentMarkers) +" Accuracy: " + str(round(accuracy,2)), end=" \r")
if len(currentMarkers) > 1: # check there is more than one marker type received
if min([currentMarkers[key][1] for key in currentMarkers]) > bci.minimumEpochsRequired:
classInfo = bci.CurrentClassifierInfo() # hangs if called too early
accuracy = classInfo["accuracy"]
if min([currentMarkers[key][1] for key in currentMarkers]) > min_epochs_test:
bci.TestMode()
break
print("Markers received: " + str(self.currentMarkers) +" Accuracy: " + str(round(self.accuracy,2)), end=" \r")
if len(self.currentMarkers) > 1: # check there is more than one marker type received
if min([self.currentMarkers[key][1] for key in self.currentMarkers]) > self.bci.minimumEpochsRequired:
classInfo = self.bci.CurrentClassifierInfo() # hangs if called too early
self.accuracy = classInfo["accuracy"]
if min([self.currentMarkers[key][1] for key in self.currentMarkers]) > self.min_epochs_test:
self.bci.TestMode()
test = True
else:
markerGuess = bci.CurrentClassifierMarkerGuess() # when in test mode only y_pred returned
guess = [key for key, value in currentMarkers.items() if value[0] == markerGuess]
markerGuess = self.bci.CurrentClassifierMarkerGuess() # when in test mode only y_pred returned
guess = [key for key, value in self.currentMarkers.items() if value[0] == markerGuess]
print("Current marker estimation: " + str(guess), end=" \r")
time.sleep(0.2)

return None
self.bci.StopThreads()
except KeyboardInterrupt: # allow user to break while loop
print("\nLoop interrupted by user.")

def stop_after_timeout(bci):
time.sleep(timeout)
def stop_after_timeout(self):
time.sleep(self.timeout)
stop_signal.set()
print("\nTimeout reached. Stopping threads.")
bci.StopThreads()

bci = PyBCI(minimumEpochsRequired = min_epochs_train, createPseudoDevice=createPseudoDevice, torchModel = PyTorchModel)
main_thread = threading.Thread(target=loop, args=(bci,))
main_thread.start()
if timeout:
timeout_thread = threading.Thread(target=stop_after_timeout, args=(bci,))
timeout_thread.start()
timeout_thread.join()
main_thread.join()

# Accessor methods on the CLI_testPytorchWrapper class
def get_accuracy(self):
return self.accuracy

def get_current_markers(self):
return self.currentMarkers

def main(createPseudoDevice=True, min_epochs_train=4, min_epochs_test=10, num_chs = 8, num_feats = 2, num_classes = 4, timeout=None):
global num_chs_g, num_feats_g, num_classes_g
num_chs_g = num_chs
num_feats_g = num_feats
num_classes_g = num_classes
command_thread = threading.Thread(target=command_listener)
command_thread.daemon = True
command_thread.start()

my_bci_wrapper = CLI_testPytorchWrapper(createPseudoDevice, min_epochs_train, min_epochs_test, num_chs, num_feats, num_classes, timeout)
command_thread.join()
return my_bci_wrapper # Return this instance
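Because main now returns the wrapper instance instead of None, callers can inspect the run after it finishes. A hedged usage sketch (the timeout=10 and epoch counts are illustrative values, mirroring Tests/test_cli.py):

from unittest.mock import patch
from pybci.CliTests.testPyTorch import main

# Mock stdin so the command listener exits without user interaction.
with patch('builtins.input', return_value='stop'):
    wrapper = main(createPseudoDevice=True, min_epochs_train=1,
                   min_epochs_test=2, timeout=10)

print(wrapper.get_accuracy())         # last computed classifier accuracy
print(wrapper.get_current_markers())  # marker counts seen during the run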


if __name__ == '__main__':
parser = argparse.ArgumentParser(description="A PyTorch neural network is used for the model and the pseudodevice generates 8 channels of 3 marker types and baseline. Similar to testPytorch.py in the examples folder.")