import os
import pickle
from time import time

import cv2
import numpy as np
from scipy.spatial import distance
from sklearn.cluster import KMeans
from tqdm import tqdm
# Cap the number of worker processes used by joblib/loky (sklearn's backend);
# avoids CPU-count detection warnings/oversubscription on some machines.
os.environ["LOKY_MAX_CPU_COUNT"] = "4"

# --- Module-level state shared by the pipeline steps below ---
image_paths = []          # database image paths collected by get_images()
sift_res = {}             # img_path -> SIFT descriptor ndarray (from extract_sift_features)
k_means = None            # fitted KMeans vocabulary model (set by create_vocabulary)
k = 50                    # vocabulary size: number of visual words / KMeans clusters
histograms = {}           # img_path -> L2-normalised BoF histogram (from compute_histograms)
sift = cv2.SIFT_create()  # shared SIFT feature extractor
def get_images(image_folder, extensions=(".jpg",)):
    """Recursively collect image file paths under *image_folder*.

    Generalized from the original hard-coded ".jpg" filter: *extensions* is a
    tuple of lower-case suffixes to accept (default preserves old behavior).

    Args:
        image_folder: Root directory to walk.
        extensions: Filename suffixes (case-insensitive) to accept.

    Returns:
        The list of paths found by this call; they are also appended to the
        module-global ``image_paths`` list, as before.
    """
    global image_paths
    exts = tuple(e.lower() for e in extensions)
    found = [
        os.path.join(root, filename)
        for root, _, files in os.walk(image_folder)
        for filename in files
        if filename.lower().endswith(exts)
    ]
    image_paths.extend(found)
    return found
def extract_sift_features(cache="cache/BOF_SIFT_desc.pkl"):
    """Step 1: compute (or load cached) SIFT descriptors for all database images.

    Populates the module-global ``sift_res`` mapping img_path -> descriptor
    ndarray. Entries are None for images that could not be read or that yield
    no keypoints; later steps must skip those.

    Args:
        cache: Pickle file used to persist/restore the descriptor mapping.
    """
    print("1.提取SIFT特征")
    global sift_res, image_paths
    if os.path.exists(cache):
        with open(cache, "rb") as f:
            print(f"从缓存中加载SIFT特征: {cache}")
            sift_res = pickle.load(f)
        return
    for img_path in tqdm(image_paths):
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            # Unreadable/corrupt file: original passed None into
            # detectAndCompute and crashed. Record None and move on.
            sift_res[img_path] = None
            continue
        _, desc = sift.detectAndCompute(img, None)
        sift_res[img_path] = desc
    # Original crashed with FileNotFoundError on first run if "cache/" did
    # not exist yet.
    os.makedirs(os.path.dirname(cache) or ".", exist_ok=True)
    with open(cache, "wb") as f:
        print(f"保存SIFT特征到缓存: {cache}")
        pickle.dump(sift_res, f)
def create_vocabulary(cache="cache/BOF_kmeans.pkl"):
    """Step 2: fit (or load cached) the KMeans visual vocabulary of size ``k``.

    Stacks every database descriptor into one matrix and clusters it; the
    fitted model is stored in the module-global ``k_means``.

    Args:
        cache: Pickle file used to persist/restore the fitted model.
    """
    print("2.创建视觉词典")
    global k_means, k
    if os.path.exists(cache):
        with open(cache, "rb") as f:
            print(f"从缓存中加载K-Means模型: {cache}")
            k_means = pickle.load(f)
        return
    st = time()
    # Skip images whose descriptor extraction failed (None) or found no
    # keypoints — original np.vstack crashed on such entries.
    descriptors_list = [d for d in sift_res.values() if d is not None and len(d)]
    all_descriptors = np.vstack(descriptors_list)
    k_means = KMeans(n_clusters=k, random_state=0, n_init=10)
    k_means.fit(all_descriptors)
    print(f"K-Means聚类耗时: {time()-st:.2f}s")
    # Ensure the cache directory exists before writing.
    os.makedirs(os.path.dirname(cache) or ".", exist_ok=True)
    with open(cache, "wb") as f:
        print(f"保存K-Means模型到缓存: {cache}")
        pickle.dump(k_means, f)
def compute_histogram(img_path):
    """Return the L2-normalised Bag-of-Features histogram for one image.

    Uses cached descriptors from ``sift_res`` when available, otherwise
    extracts SIFT on the fly. Requires ``k_means`` to be fitted first.

    Returns:
        A length-``k`` float vector. All-zero when the image has no
        descriptors — the original crashed in ``predict(None)`` or produced
        NaNs from 0/0 normalisation in that case.
    """
    if img_path in sift_res:
        descriptors = sift_res[img_path]
    else:
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            descriptors = None
        else:
            _, descriptors = sift.detectAndCompute(img, None)
    if descriptors is None or len(descriptors) == 0:
        return np.zeros(k, dtype=float)
    words = k_means.predict(descriptors)
    hist, _ = np.histogram(words, bins=np.arange(k + 1))
    norm = np.linalg.norm(hist)
    # Guard the 0/0 case explicitly instead of emitting NaNs.
    return hist / norm if norm > 0 else hist.astype(float)
def compute_histograms(cache="cache/BOF_histogram.pkl"):
    """Step 3: compute (or load cached) BoF histograms for all database images.

    Populates the module-global ``histograms`` mapping img_path -> histogram.

    Args:
        cache: Pickle file used to persist/restore the histogram mapping.
    """
    print("3.计算数据库的BoF直方图")
    global image_paths, histograms, k_means
    if os.path.exists(cache):
        with open(cache, "rb") as f:
            print(f"从缓存中加载数据库直方图: {cache}")
            histograms = pickle.load(f)
        # BUG FIX: the original was missing this return, so it fell through
        # and recomputed + re-saved every histogram even on a cache hit,
        # defeating the cache entirely.
        return
    for img_path in tqdm(image_paths):
        histograms[img_path] = compute_histogram(img_path)
    # Ensure the cache directory exists before writing.
    os.makedirs(os.path.dirname(cache) or ".", exist_ok=True)
    with open(cache, "wb") as f:
        print(f"保存数据库直方图到缓存: {cache}")
        pickle.dump(histograms, f)
def cos_sim(a, b):
    """Cosine similarity of two vectors.

    Args:
        a, b: 1-D array-likes of equal length.

    Returns:
        ``dot(a, b) / (|a| * |b|)`` as a float; 0.0 when either vector is
        all-zero (the original returned NaN via 0/0, which then poisoned the
        similarity ranking for descriptor-less images).
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        return 0.0
    return float(np.dot(a, b) / denom)
def get_topk(query_hist, topk=10):
    """Rank all database images by cosine similarity to *query_hist*.

    Args:
        query_hist: BoF histogram of the query image.
        topk: Number of best matches to return.

    Returns:
        List of ``(img_path, similarity)`` tuples, best match first.
    """
    print("4.直方图匹配")
    global histograms
    scored = [
        (img_path, cos_sim(query_hist, db_hist))
        for img_path, db_hist in tqdm(histograms.items())
    ]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:topk]
def _main():
    """Run the full BoF retrieval pipeline and print the top-10 matches.

    Steps: collect database images, extract SIFT features, build the visual
    vocabulary, compute database histograms, then query a single test image.
    """
    image_folder = "image"
    get_images(image_folder)
    extract_sift_features()
    create_vocabulary()
    compute_histograms()

    test_img = "query/A0C573_20151103073308_3029240562.jpg"
    test_hist = compute_histogram(test_img)
    topk = get_topk(test_hist, topk=10)
    print("Top 10 相似图片:")
    for i, (filepath, sim) in enumerate(topk):
        print(f"{i+1}. {filepath} - 相似度: {sim:.4f}")


# Guard the pipeline so importing this module (e.g. to reuse the helpers)
# no longer triggers a full feature-extraction run as a side effect.
if __name__ == "__main__":
    _main()