Files
Mini-Nav/mini-nav/benchmarks/task_eval.py

78 lines
2.1 KiB
Python

from typing import Literal, cast
import polars as pl
import torch
from datasets import Dataset, load_dataset
from torch import Tensor, nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import BitImageProcessorFast
from utils import get_device
def establish_database(
    processor: BitImageProcessorFast,
    model: nn.Module,
    dataset: Dataset,
    batch_size: int = 32,
) -> pl.DataFrame:
    """Run *model* over *dataset* and collect one feature row per sample.

    Each batch is read from the ``"img"`` and ``"label"`` keys, passed through
    *processor* and *model* with gradients disabled, and the resulting
    per-sample vectors are stored together with their labels.

    Args:
        processor: Image processor called as
            ``processor(imgs, return_tensors="pt")``.
        model: Module used as the feature extractor; it is switched to eval
            mode as a side effect.
        dataset: Dataset providing ``"img"`` and ``"label"`` entries.
        batch_size: Number of samples per forward pass.

    Returns:
        A DataFrame with an integer ``label`` column and a ``feature`` column
        holding each sample's model output flattened to a list of floats.
        The frame is empty (but keeps both columns) when *dataset* is empty.
    """
    # The device does not change between batches, so resolve it once.
    device = get_device()
    model.eval()

    # Shuffling buys nothing for inference; keeping dataset order makes the
    # resulting rows reproducible and traceable back to their source sample.
    dataloader = DataLoader(
        dataset.with_format("torch"),
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )

    labels_acc: list[int] = []
    features_acc: list[list[float]] = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Establish Database"):
            inputs = processor(batch["img"], return_tensors="pt").to(device)
            # NOTE(review): the whole processor output is passed positionally;
            # if `model` is a plain transformers model it may need
            # `model(**inputs)` instead - confirm against the caller.
            outputs = cast(Tensor, model(inputs))

            # Assumes dim 0 of `outputs` is the batch dimension - TODO confirm.
            # Remaining dims are flattened so every sample is one flat vector.
            vectors = outputs.reshape(outputs.shape[0], -1).float().cpu()
            features_acc.extend(vectors.tolist())
            labels_acc.extend(torch.as_tensor(batch["label"]).tolist())

    # An explicit schema keeps the column dtypes stable even with zero rows.
    return pl.DataFrame(
        {"label": labels_acc, "feature": features_acc},
        schema={"label": pl.Int64, "feature": pl.List(pl.Float64)},
    )
def task_eval(
processor: BitImageProcessorFast,
model: nn.Module,
dataset: Literal["CIFAR-10", "CIFAR-100"],
top_k: int = 10,
batch_size: int = 32,
):
match dataset:
case "CIFAR-10":
train_dataset = load_dataset("uoft-cs/cifar10", split="train")
test_dataset = load_dataset("uoft-cs/cifar10", split="test")
case "CIFAR-100":
train_dataset = load_dataset("uoft-cs/cifar100", split="train")
test_dataset = load_dataset("uoft-cs/cifar100", split="test")
case _:
raise ValueError(
f"Unknown dataset: {dataset}. Only support: 'CIFAR-10', 'CIFAR-100'."
)
# Establish database
df = establish_database(processor, model, train_dataset, batch_size)
# Test
dataloader = DataLoader(
test_dataset.with_format("torch"),
batch_size=batch_size,
shuffle=True,
num_workers=4,
)
with torch.no_grad():
for batch in tqdm(dataloader, desc="Test Evaluation"):
imgs = batch["img"]
labels = batch["label"]
inputs = processor(imgs, return_tensors="pt").to(get_device())
outputs = cast(Tensor, model(inputs))
for vec in outputs:
pass