import torch
from torchvision import models,transforms
from torch.utils.data import Dataset , DataLoader
import os
import pickle
from PIL import Image
from tqdm import tqdm
from pymilvus import (
FieldSchema,
DataType,
db,
connections,
CollectionSchema,
Collection
)
import time
import matplotlib.pyplot as plt
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing pipeline, spelled out step by step:
#   1. resize to 256x256, 2. center-crop to 224x224,
#   3. convert the PIL image to a tensor, 4. normalize each channel.
_resize_step = transforms.Resize((256, 256))
_crop_step = transforms.CenterCrop(224)
_tensor_step = transforms.ToTensor()
# Per-channel mean / std passed to Normalize (in that order).
_normalize_step = transforms.Normalize(
    [0.485, 0.456, 0.406],
    [0.229, 0.224, 0.225],
)
transform = transforms.Compose(
    [_resize_step, _crop_step, _tensor_step, _normalize_step]
)
# Root folder that is scanned for images; every sub-directory directly below
# it is listed (presumably one sub-directory per class -- confirm against the
# dataset layout).
image_dir = "./flower_data/train"

# Build the sub-directory list from what is actually on disk instead of
# pairing names with a hard-coded count of 102 (which silently dropped any
# entries past the 102nd). Stray files in the root are skipped so that
# os.listdir() below is never called on a non-directory.
image_dirs = [
    f"{image_dir}/{entry}"
    for entry in os.listdir(image_dir)
    if os.path.isdir(f"{image_dir}/{entry}")
]

# Flat list of every file path found inside those sub-directories.
image_paths = [
    os.path.join(sub_dir, file_name)
    for sub_dir in image_dirs
    for file_name in os.listdir(sub_dir)
]
image_paths
image_dirs

# Persist the path list; ImageDataset reads it back from this pickle file.
with open("image_paths.pkl" , "wb" ) as fw:
    pickle.dump(image_paths, fw)
class ImageDataset(Dataset):
    """Dataset over the RGB images listed in ``./image_paths.pkl``.

    The pickle file is expected to contain a list of image file paths. Only
    images whose PIL mode is ``"RGB"`` are kept; every other path is dropped
    at construction time.

    Each item is a dict with the keys ``"idx"`` (position in the filtered
    list), ``"image_path"`` (the file path) and ``"img"`` (the image, after
    ``transform`` if one was given).
    """

    def __init__(self , transform =None):
        """Load the path list and keep only the paths of RGB images.

        Args:
            transform: Optional callable applied to each PIL image in
                ``__getitem__``.
        """
        super().__init__()
        self.transform = transform
        # NOTE: pickle.load executes arbitrary code from the file; only use
        # a path list that this project wrote itself.
        with open("./image_paths.pkl", "rb") as fr:
            self.data_paths = pickle.load(fr)

        self.data = []

        for image_path in self.data_paths:
            # Open inside a context manager so the file handle is released
            # right away; scanning thousands of files without closing them
            # can exhaust the process's file descriptors.
            with Image.open(image_path) as img:
                if img.mode == "RGB":
                    self.data.append(image_path)

    def __len__(self):
        """Return the number of RGB images kept by the constructor."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict (idx, image_path, img)."""
        image_path = self.data[index]
        # convert() decodes the pixels and returns an independent in-memory
        # copy, so the underlying file can be closed before the image is
        # used. The paths were filtered to RGB, so the mode is unchanged.
        with Image.open(image_path) as source:
            img = source.convert("RGB")

        if self.transform:
            img = self.transform(img)

        dict_data = {
            "idx" : index,
            "image_path" : image_path,
            "img" : img
        }
        return dict_data
# Build the dataset with the preprocessing pipeline defined above.
valid_dataset = ImageDataset(transform=transform)
# Bare expression: evaluates the dataset size (only displayed when this is
# the last expression of an interactive/notebook cell).
len(valid_dataset)
# shuffle=False keeps batches in dataset order; 64 images per batch.
valid_dataloader = DataLoader(valid_dataset , batch_size=64, shuffle=False)
def load_model():
    """Return a pretrained ResNet-18 moved to ``device`` and set to eval mode.

    Returns:
        The torchvision ResNet-18 model, on the module-level ``device``,
        with ``eval()`` applied.
    """
    # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
    # ``weights`` enum. IMAGENET1K_V1 is the checkpoint that
    # ``pretrained=True`` used to load, so the resulting model is the same.
    # Older torchvision releases have no enum, hence the fallback.
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = models.resnet18(pretrained = True)
    model.to(device)
    model.eval()
    return model
# Instantiate the network once; it is reused for every batch below.
model = load_model()
# Bare expression: echoes the model's architecture when run as a notebook cell.
model
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
def feature_extract(model, x):
    """Run ``x`` through every stage of ``model`` except its ``fc`` layer.

    Args:
        model: Object exposing the sub-modules ``conv1``, ``bn1``, ``relu``,
            ``maxpool``, ``layer1`` .. ``layer4`` and ``avgpool``.
        x: Input tensor accepted by ``model.conv1``.

    Returns:
        The ``avgpool`` output flattened from dimension 1 onward, i.e. one
        feature vector per input sample.
    """
    # Stages are applied strictly in this order; the classifier head is
    # deliberately left out so the pooled features are returned instead.
    stage_names = (
        "conv1", "bn1", "relu", "maxpool",
        "layer1", "layer2", "layer3", "layer4",
        "avgpool",
    )
    activations = x
    for stage_name in stage_names:
        activations = getattr(model, stage_name)(activations)
    return torch.flatten(activations, 1)
# Accumulators filled batch by batch; the three lists stay index-aligned.
feature_list = []              # one numpy feature vector per image
feature_index_list = []        # dataset index of each image (plain ints)
feature_image_path_list = []   # file path of each image

# Inference only: disable autograd so no graph is built or kept alive for
# each batch. This also makes the legacy ``.data`` access unnecessary.
with torch.no_grad():
    for batch in tqdm(valid_dataloader):
        batch_images = batch["img"].to(device)
        batch_features = feature_extract(model, batch_images).cpu().numpy()
        # extend() iterates over the first axis, adding one row per image.
        feature_list.extend(batch_features)
        # The collated indices arrive as a tensor; store plain Python ints
        # rather than 0-d tensors.
        feature_index_list.extend(batch["idx"].tolist())
        # A dedicated name is not bound here on purpose: the original loop
        # reused ``image_paths`` and clobbered the module-level path list.
        feature_image_path_list.extend(batch["image_path"])

# Column-oriented data: [all image paths, all feature vectors].
entities = [
    feature_image_path_list,
    feature_list,
]
len(feature_list)
entities[0]
# Primary-key column: the image path, supplied by the caller (auto_id off).
# The description string reads "image path".
_path_field = FieldSchema(
    name="image_path",
    dtype=DataType.VARCHAR,
    description="图片路径",
    max_length=512,
    is_primary=True,
    auto_id=False,
)
# Vector column holding the 512-dimensional float embedding.
# The description string reads "vector representing the image".
_embedding_field = FieldSchema(
    name="embeddings",
    dtype=DataType.FLOAT_VECTOR,
    description="向量表示图片",
    is_primary=False,
    dim=512,
)
fields = [_path_field, _embedding_field]

# Collection schema; the description reads "table used for image-to-image search".
schema = CollectionSchema(fields, description="用于图生图的表")

# Register the connection under an alias, then bind the collection to it.
_milvus_alias = "power_image_search"
connections.connect(
    _milvus_alias,
    host="ljxwtl.cn",
    port=19530,
    db_name="power_image_search",
)
table = Collection(
    "image_to_image",
    schema=schema,
    consistency_level="Strong",
    using=_milvus_alias,
)
# Insert in column-oriented batches instead of issuing one request per row:
# the per-row loop made a separate insert call for every image. Each batch
# is [paths, vectors], with both slices covering the same rows.
insert_batch_size = 1000
for batch_start in range(0, len(feature_image_path_list), insert_batch_size):
    batch_stop = batch_start + insert_batch_size
    table.insert([
        feature_image_path_list[batch_start:batch_stop],
        feature_list[batch_start:batch_stop],
    ])
# Flush before reading num_entities so the count reflects the inserts above.
table.flush()
table.num_entities

6552

# Index definition for the vector column: IVF_FLAT with the L2 metric and
# nlist=128.
index = dict(
    index_type="IVF_FLAT",
    metric_type="L2",
    params=dict(nlist=128),
)
# Build the index on the "embeddings" field, then load the collection so it
# can be searched.
table.create_index("embeddings", index_params=index)
table.load()
# Query with a single stored vector: the slice keeps the second feature
# vector wrapped in a list, which is the shape search() takes for its data.
vectors_to_search = entities[-1][1:2]
search_params = dict(
    metric_type="L2",
    params=dict(nprobe=10),
)

# Timestamps taken immediately before and after the search call.
start_time = time.time()
result = table.search(
    data=vectors_to_search,
    anns_field="embeddings",
    param=search_params,
    limit=5,
    output_fields=["image_path"],
)
end_time = time.time()

# One hit list per query vector; print every hit with its stored image path.
for hits in result:
    for hit in hits:
        print(f"hit: {hit}, image_path field: {hit.entity.get('image_path')}")

深度学习之使用Milvus向量数据库实战图搜图-LMLPHP

 

# Display two images one after the other: first the image whose path is the
# second entry of the stored path list, then a fixed file from the data set.
for _shown_path in (entities[0][1], "./flower_data/train/1\\image_06766.jpg"):
    img_data = plt.imread(_shown_path)
    plt.imshow(img_data)
    plt.show()
11-20 14:17