Comprehension Question - Negating Input-Images #20

Open
42elenz opened this issue Mar 4, 2024 · 0 comments

Thank you for your awesome code and work.
I have a question:
In the `train_batch` function in `run_train.py`, you negate the images (`-images`) in the forward pass:
logits_per_image, logits_per_text = model(-images, texts)

Why do you do so?
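
To make the question concrete: my guess is that the negation acts as an intensity inversion, but I could not confirm that from the code. The sketch below only shows that negating a standardized image gives the same values as standardizing the inverted image. It assumes per-image standardization and pixel values in [0, 1], which are my assumptions and not taken from the repository; `standardize` and `invert` are hypothetical helper names.

```python
from statistics import fmean, pstdev


def standardize(pixels):
    """Zero-mean, unit-variance scaling using the image's own statistics (assumption)."""
    mu = fmean(pixels)
    sigma = pstdev(pixels)
    return [(p - mu) / sigma for p in pixels]


def invert(pixels, max_value=1.0):
    """Intensity inversion: dark pixels become bright and vice versa."""
    return [max_value - p for p in pixels]


# A tiny flattened "image" with values in [0, 1].
pixels = [0.0, 0.25, 0.5, 1.0]

# What `-images` does to an already standardized input.
negated = [-z for z in standardize(pixels)]

# Standardizing the intensity-inverted image instead.
inverted = standardize(invert(pixels))

# Both routes give the same values up to floating-point error.
matches = all(abs(a - b) < 1e-9 for a, b in zip(negated, inverted))
print(matches)
```

If the preprocessing uses fixed dataset statistics instead of per-image ones, the two routes differ by a constant offset, so this equivalence would not hold exactly.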

Also, in the `train` function you initialize `highest_val_auc` but never use it. How would you define the AUC in such a scenario?

```python
def train(model, loader, device, criterion, optimizer, config):
    model_save_dir = os.path.join(config.save_dir, config.model_name)
    if not os.path.exists(model_save_dir):
        # Create a new folder if not exists
        os.makedirs(model_save_dir)

    # Run training
    total_batches = len(loader) * config.epochs
    example_ct = 0  # number of examples seen
    batch_ct = 0
    report_freq = config.log_interval
    highest_val_auc = 0 # save highest mean auc

    for epoch in range(config.epochs):
        running_loss = 0.0 # running loss over batch
        for data in tqdm(loader):
            # get the images
            images = data['img']

            texts = data['txt']
            texts = preprocess_text(texts, model)

            # perform step for a single batch
            loss = train_batch(images, texts, model, device, criterion, optimizer)
            example_ct +=  len(images)
            batch_ct += 1
            running_loss += loss.item()

            # Report metrics every `report_freq` batch
            if (batch_ct % report_freq) == 0:
                train_log(running_loss / report_freq, example_ct, epoch)
                running_loss = 0.0

            if (batch_ct % config.save_interval) == 0:
                model_path = os.path.join(model_save_dir, "checkpoint_{batch_ct}.pt".format(
                    batch_ct=str(batch_ct),
                ))
                print("Saved checkpoint to: ", model_path)
                save(model, model_path)
```
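
For reference, this is roughly how I would have expected `highest_val_auc` to be used: compute a mean AUC over the classes on a labeled validation set at each save interval, and keep a checkpoint only when that value improves. This is just a sketch of my understanding, not your method. It assumes binary labels and one model score per class, and `roc_auc`, `mean_auc`, and `update_best` are hypothetical names that do not exist in the repository.

```python
def roc_auc(labels, scores):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties between a positive and a negative score count as half a win.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative label")
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0
        for p in pos
        for n in neg
    )
    return wins / (len(pos) * len(neg))


def mean_auc(labels_per_class, scores_per_class):
    """Average the per-class AUCs, one class per entry."""
    aucs = [roc_auc(y, s) for y, s in zip(labels_per_class, scores_per_class)]
    return sum(aucs) / len(aucs)


def update_best(val_auc, highest_val_auc):
    """Return (new_best, improved) so the caller can decide whether to save."""
    if val_auc > highest_val_auc:
        return val_auc, True
    return highest_val_auc, False


# Toy validation data: two classes, four samples each.
labels = [[0, 0, 1, 1], [1, 0, 1, 0]]
scores = [[0.1, 0.4, 0.35, 0.8], [0.9, 0.2, 0.6, 0.7]]

highest_val_auc = 0.0
val_auc = mean_auc(labels, scores)
highest_val_auc, improved = update_best(val_auc, highest_val_auc)
print(val_auc, improved)  # 0.75 True
```

What I cannot tell from the code is where the per-class scores would come from in this setup, for example from the similarity between an image and a text prompt per class. That is the part I am asking about.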

I hope you can help me with my questions!

Thanks again for making your code publicly available, and greetings from Germany.
