mirror of
https://github.com/SikongJueluo/Mini-Nav.git
synced 2026-03-12 12:25:32 +08:00
78 lines
2.1 KiB
Python
78 lines
2.1 KiB
Python
from typing import Literal, cast
|
|
|
|
import polars as pl
|
|
import torch
|
|
from datasets import Dataset, load_dataset
|
|
from torch import Tensor, nn
|
|
from torch.utils.data import DataLoader
|
|
from tqdm.auto import tqdm
|
|
from transformers import BitImageProcessorFast
|
|
from utils import get_device
|
|
|
|
|
|
def establish_database(
    processor: BitImageProcessorFast,
    model: nn.Module,
    dataset: Dataset,
    batch_size: int = 32,
) -> pl.DataFrame:
    """Run ``model`` over ``dataset`` and collect its outputs into a table.

    Args:
        processor: Image processor applied to every batch of ``batch["img"]``.
        model: Network called on the processed batch; it is switched to eval
            mode before the first forward pass.
        dataset: Dataset whose items provide ``"img"`` and ``"label"`` entries.
        batch_size: Number of samples per forward pass.

    Returns:
        A frame with one row per sample and two columns: ``label`` (the value
        read from ``batch["label"]``) and ``vector`` (the model output for
        that sample, flattened to one dimension). An empty frame is returned
        when the loader yields no batches.
    """
    # The target device does not change between batches, so resolve it once.
    device = get_device()

    model.eval()
    dataloader = DataLoader(
        dataset.with_format("torch"),
        batch_size=batch_size,
        # Inference only: shuffling adds nothing and would make the row order
        # of the resulting table differ from run to run.
        shuffle=False,
        num_workers=4,
    )

    # One small frame per batch keeps the Python-list conversion bounded by
    # the batch size instead of the whole dataset.
    frames: list[pl.DataFrame] = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Establish Database"):
            imgs = batch["img"]
            labels = batch["label"]

            inputs = processor(imgs, return_tensors="pt").to(device)

            # NOTE(review): assumes the model returns a plain tensor batched
            # along dim 0, as the cast implies — confirm against the model.
            outputs = cast(Tensor, model(inputs))

            # Flatten everything after the batch axis so each sample becomes
            # a single 1-D vector regardless of the output rank.
            vectors = outputs.detach().cpu().reshape(outputs.shape[0], -1)

            frames.append(
                pl.DataFrame(
                    {
                        "label": torch.as_tensor(labels).tolist(),
                        "vector": vectors.tolist(),
                    }
                )
            )

    if not frames:
        # pl.concat rejects an empty sequence; keep the empty-frame result.
        return pl.DataFrame()
    return pl.concat(frames)
|
|
|
|
|
|
def _normalize_label_column(dataset: Dataset, source_column: str) -> Dataset:
    """Return ``dataset`` with ``source_column`` renamed to ``"label"`` if needed."""
    columns = dataset.column_names
    if "label" in columns or source_column not in columns:
        # Already in the expected shape, or nothing to rename from.
        return dataset
    return dataset.rename_column(source_column, "label")


def task_eval(
    processor: BitImageProcessorFast,
    model: nn.Module,
    dataset: Literal["CIFAR-10", "CIFAR-100"],
    top_k: int = 10,
    batch_size: int = 32,
) -> None:
    """Build a reference database from a CIFAR train split and run the test split.

    The scoring step is not implemented yet: the model is run over every test
    batch, but the outputs are not compared against the database, and
    ``top_k`` is currently unused.

    Args:
        processor: Image processor applied to every batch of images.
        model: Network used for both the database and the test pass.
        dataset: Which benchmark to load, ``"CIFAR-10"`` or ``"CIFAR-100"``.
        top_k: Reserved for the retrieval step; not read yet.
        batch_size: Number of samples per forward pass.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset == "CIFAR-10":
        repo_id, label_column = "uoft-cs/cifar10", "label"
    elif dataset == "CIFAR-100":
        # The CIFAR-100 hub dataset names its class column "fine_label"
        # (next to "coarse_label"); there is no "label" column.
        repo_id, label_column = "uoft-cs/cifar100", "fine_label"
    else:
        raise ValueError(
            f"Unknown dataset: {dataset}. Only support: 'CIFAR-10', 'CIFAR-100'."
        )

    train_dataset = load_dataset(repo_id, split="train")
    test_dataset = load_dataset(repo_id, split="test")

    # Both loops below read batch["label"], so expose the class column under
    # that name for every supported dataset.
    train_dataset = _normalize_label_column(train_dataset, label_column)
    test_dataset = _normalize_label_column(test_dataset, label_column)

    # Establish database
    df = establish_database(processor, model, train_dataset, batch_size)

    # Test
    dataloader = DataLoader(
        test_dataset.with_format("torch"),
        batch_size=batch_size,
        # Evaluation does not depend on sample order; keep it deterministic.
        shuffle=False,
        num_workers=4,
    )

    # The target device does not change between batches, so resolve it once.
    device = get_device()

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Test Evaluation"):
            imgs = batch["img"]
            labels = batch["label"]

            inputs = processor(imgs, return_tensors="pt").to(device)

            outputs = cast(Tensor, model(inputs))

            # TODO: rank each output vector against `df`, keep the `top_k`
            # nearest entries and score them against `labels`.
            for vec in outputs:
                pass
|