leedoming committed on
Commit
b404c7a
1 Parent(s): ed430f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -46
app.py CHANGED
@@ -11,6 +11,8 @@ import cv2
11
  from inference_sdk import InferenceHTTPClient
12
  import matplotlib.pyplot as plt
13
  import base64
 
 
14
 
15
  # Load model and tokenizer
16
  @st.cache_resource
@@ -30,88 +32,77 @@ def load_data():
30
 
31
  data = load_data()
32
 
33
- # Helper functions
34
- @st.cache_data
 
 
 
 
35
  def download_and_process_image(image_url):
36
  try:
37
  response = requests.get(image_url)
38
- response.raise_for_status() # Raises an HTTPError for bad responses
39
  image = Image.open(BytesIO(response.content))
40
-
41
- # Convert image to RGB mode if it's in RGBA mode
42
  if image.mode == 'RGBA':
43
  image = image.convert('RGB')
44
-
45
  return image
46
- except requests.RequestException as e:
47
- st.error(f"Error downloading image: {e}")
48
- return None
49
  except Exception as e:
50
- st.error(f"Error processing image: {e}")
51
  return None
52
 
53
def get_image_embedding(image):
    """Encode an image with the module-level model and return a NumPy array.

    The image is run through ``preprocess_val``, given a leading batch axis,
    moved to ``device`` and encoded with ``model.encode_image`` without
    gradient tracking. The features are divided by their norm along the
    last dimension before being copied to the CPU.
    """
    batch = preprocess_val(image)
    batch = batch.unsqueeze(0)  # leading batch axis of size 1
    batch = batch.to(device)
    with torch.no_grad():
        features = model.encode_image(batch)
        norms = features.norm(dim=-1, keepdim=True)
        features /= norms
    return features.cpu().numpy()
59
-
60
def setup_roboflow_client(api_key):
    """Create the Roboflow inference client used for segmentation.

    Args:
        api_key: Roboflow API key passed through to the client.

    Returns:
        An ``InferenceHTTPClient`` pointed at ``https://outline.roboflow.com``.
    """
    client_options = {
        "api_url": "https://outline.roboflow.com",
        "api_key": api_key,
    }
    return InferenceHTTPClient(**client_options)
65
-
66
def segment_image(image_path, client):
    """Mask an image down to the regions predicted by the Roboflow model.

    The file at ``image_path`` is sent base64-encoded to model ``closet/1``.
    Each predicted polygon is filled white on a mask, and the image (resized
    to 800x600) is AND-ed with that mask.

    Args:
        image_path: Path of the image file on disk.
        client: Object exposing ``infer(encoded_image, model_id=...)``.

    Returns:
        A PIL image: the masked image, or the resized original when the model
        returned no predictions. On any error the failure is reported with
        ``st.error`` and ``Image.open(image_path)`` is returned instead.
    """
    try:
        # Read the image file and base64-encode it for the API call.
        with open(image_path, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode('utf-8')

        # Load the original image.
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None instead of raising on unreadable files.
            raise ValueError(f"Could not read image: {image_path}")
        source_height, source_width = image.shape[:2]
        image = cv2.resize(image, (800, 600))
        mask = np.zeros(image.shape, dtype=np.uint8)

        # Call the Roboflow API.
        results = client.infer(encoded_image, model_id="closet/1")

        # Parse a string response once, so that both the predictions and the
        # image size below are read from the decoded payload.
        if not isinstance(results, dict):
            results = json.loads(results)
        predictions = results.get('predictions', [])

        if predictions:
            # The scale is the same for every polygon. Fall back to the size
            # of the file on disk if the response omits the image size.
            image_info = results.get('image') or {}
            scale_x = image.shape[1] / (image_info.get('width') or source_width)
            scale_y = image.shape[0] / (image_info.get('height') or source_height)

            for prediction in predictions:
                # Keep float precision until after scaling, then round down
                # to the integer pixel grid cv2.fillPoly expects.
                pts = np.array(
                    [[p['x'], p['y']] for p in prediction['points']],
                    dtype=np.float64,
                )
                pts = (pts * [scale_x, scale_y]).astype(np.int32)
                pts = pts.reshape((-1, 1, 2))
                cv2.fillPoly(mask, [pts], color=(255, 255, 255))  # White mask

            segmented_image = cv2.bitwise_and(image, mask)
        else:
            st.warning("No predictions found in the image. Returning original image.")
            segmented_image = image

        return Image.fromarray(cv2.cvtColor(segmented_image, cv2.COLOR_BGR2RGB))
    except Exception as e:
        st.error(f"Error in segmentation: {str(e)}")
        # Re-read the original image and return it unsegmented.
        return Image.open(image_path)
 
 
 
 
 
 
111
 
112
  @st.cache_data
113
  def process_database_cached(data):
114
- database_embeddings = []
115
  database_info = []
116
  for item in data:
117
  image_url = item['이미지 링크'][0]
@@ -121,7 +112,6 @@ def process_database_cached(data):
121
  if image is None:
122
  continue
123
 
124
- # Save the image temporarily
125
  temp_path = f"temp_{product_id}.jpg"
126
  image.save(temp_path, 'JPEG')
127
 
@@ -140,17 +130,42 @@ def process_database_cached(data):
140
 
141
  def process_database(client, data):
142
  database_info = process_database_cached(data)
143
- database_embeddings = []
 
144
 
 
145
  for item in database_info:
146
- segmented_image = segment_image(item['temp_path'], client)
 
 
 
 
 
 
 
 
 
147
  embedding = get_image_embedding(segmented_image)
148
  database_embeddings.append(embedding)
 
149
 
150
  return np.vstack(database_embeddings), database_info
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  # Streamlit app
153
- st.title("Fashion Search App with Segmentation")
154
 
155
  # API Key input
156
  api_key = st.text_input("Enter your Roboflow API Key", type="password")
@@ -168,17 +183,18 @@ if api_key:
168
 
169
  if st.button('Find Similar Items'):
170
  with st.spinner('Processing...'):
171
- # Save uploaded image temporarily
172
  temp_path = "temp_upload.jpg"
173
  image.save(temp_path)
174
 
175
- # Segment the uploaded image
176
- segmented_image = segment_image(temp_path, CLIENT)
177
  st.image(segmented_image, caption='Segmented Image', use_column_width=True)
178
 
179
- # Get embedding for segmented image
 
 
 
180
  query_embedding = get_image_embedding(segmented_image)
181
- similar_images = find_similar_images(query_embedding)
182
 
183
  st.subheader("Similar Items:")
184
  for img in similar_images:
@@ -192,5 +208,9 @@ if api_key:
192
  st.write(f"Price: {img['info']['price']}")
193
  st.write(f"Discount: {img['info']['discount']}%")
194
  st.write(f"Similarity: {img['similarity']:.2f}")
 
 
 
 
195
  else:
196
  st.warning("Please enter your Roboflow API Key to use the app.")
 
11
  from inference_sdk import InferenceHTTPClient
12
  import matplotlib.pyplot as plt
13
  import base64
14
+ import os
15
+ import pickle
16
 
17
  # Load model and tokenizer
18
  @st.cache_resource
 
32
 
33
  data = load_data()
34
 
35
def setup_roboflow_client(api_key):
    """Build a Roboflow inference client bound to the outline endpoint.

    Args:
        api_key: Roboflow API key forwarded to the client.

    Returns:
        The ``InferenceHTTPClient`` constructed with that key.
    """
    endpoint = "https://outline.roboflow.com"
    roboflow_client = InferenceHTTPClient(api_url=endpoint, api_key=api_key)
    return roboflow_client
40
+
41
  def download_and_process_image(image_url):
42
  try:
43
  response = requests.get(image_url)
44
+ response.raise_for_status()
45
  image = Image.open(BytesIO(response.content))
 
 
46
  if image.mode == 'RGBA':
47
  image = image.convert('RGB')
 
48
  return image
 
 
 
49
  except Exception as e:
50
+ st.error(f"Error downloading/processing image: {str(e)}")
51
  return None
52
 
53
def segment_image_and_get_categories(image_path, client):
    """Segment an image with the Roboflow model and list the detected classes.

    The file at ``image_path`` is sent base64-encoded to model ``closet/1``.
    Each predicted polygon is filled white on a mask, and the image (resized
    to 800x600) is AND-ed with that mask.

    Args:
        image_path: Path of the image file on disk.
        client: Object exposing ``infer(encoded_image, model_id=...)``.

    Returns:
        ``(image, categories)``. ``image`` is a PIL image: the masked image,
        or the resized original when nothing was predicted. ``categories`` is
        a list of ``"<class> (<confidence>)"`` strings, one per prediction.
        On any error the failure is reported with ``st.error`` and
        ``(Image.open(image_path), [])`` is returned.
    """
    try:
        with open(image_path, "rb") as image_file:
            encoded_image = base64.b64encode(image_file.read()).decode('utf-8')

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None instead of raising on unreadable files.
            raise ValueError(f"Could not read image: {image_path}")
        source_height, source_width = image.shape[:2]
        image = cv2.resize(image, (800, 600))
        mask = np.zeros(image.shape, dtype=np.uint8)

        results = client.infer(encoded_image, model_id="closet/1")

        # Parse a string response once, so that both the predictions and the
        # image size below are read from the decoded payload.
        if not isinstance(results, dict):
            results = json.loads(results)
        predictions = results.get('predictions', [])

        categories = []
        if predictions:
            # The scale is the same for every polygon. Fall back to the size
            # of the file on disk if the response omits the image size; a
            # default of 1 would blow every polygon far outside the frame.
            image_info = results.get('image') or {}
            scale_x = image.shape[1] / (image_info.get('width') or source_width)
            scale_y = image.shape[0] / (image_info.get('height') or source_height)

            for prediction in predictions:
                # Keep float precision until after scaling, then round down
                # to the integer pixel grid cv2.fillPoly expects.
                pts = np.array(
                    [[p['x'], p['y']] for p in prediction['points']],
                    dtype=np.float64,
                )
                pts = (pts * [scale_x, scale_y]).astype(np.int32)
                pts = pts.reshape((-1, 1, 2))
                cv2.fillPoly(mask, [pts], color=(255, 255, 255))

                category = prediction.get('class', 'Unknown')
                confidence = prediction.get('confidence', 0)
                categories.append(f"{category} ({confidence:.2f})")

            segmented_image = cv2.bitwise_and(image, mask)
        else:
            st.warning("No predictions found in the image. Returning original image.")
            segmented_image = image

        return Image.fromarray(cv2.cvtColor(segmented_image, cv2.COLOR_BGR2RGB)), categories
    except Exception as e:
        st.error(f"Error in segmentation: {str(e)}")
        return Image.open(image_path), []
96
+
97
def get_image_embedding(image):
    """Return the normalised model embedding of ``image`` as a NumPy array.

    The image is preprocessed with ``preprocess_val``, batched, moved to
    ``device`` and passed to ``model.encode_image`` with gradients disabled.
    The result is scaled by its norm over the last dimension.
    """
    model_input = preprocess_val(image).unsqueeze(0)  # add the batch axis
    model_input = model_input.to(device)
    with torch.no_grad():
        embedding = model.encode_image(model_input)
        embedding /= embedding.norm(dim=-1, keepdim=True)
    embedding_on_cpu = embedding.cpu()
    return embedding_on_cpu.numpy()
103
 
104
  @st.cache_data
105
  def process_database_cached(data):
 
106
  database_info = []
107
  for item in data:
108
  image_url = item['이미지 링크'][0]
 
112
  if image is None:
113
  continue
114
 
 
115
  temp_path = f"temp_{product_id}.jpg"
116
  image.save(temp_path, 'JPEG')
117
 
 
130
 
131
  def process_database(client, data):
132
  database_info = process_database_cached(data)
133
+ cache_dir = "segmentation_cache"
134
+ os.makedirs(cache_dir, exist_ok=True)
135
 
136
+ database_embeddings = []
137
  for item in database_info:
138
+ cache_file = os.path.join(cache_dir, f"{item['id']}_segmented.pkl")
139
+
140
+ if os.path.exists(cache_file):
141
+ with open(cache_file, 'rb') as f:
142
+ segmented_image, categories = pickle.load(f)
143
+ else:
144
+ segmented_image, categories = segment_image_and_get_categories(item['temp_path'], client)
145
+ with open(cache_file, 'wb') as f:
146
+ pickle.dump((segmented_image, categories), f)
147
+
148
  embedding = get_image_embedding(segmented_image)
149
  database_embeddings.append(embedding)
150
+ item['categories'] = categories
151
 
152
  return np.vstack(database_embeddings), database_info
153
 
154
def find_similar_images(query_embedding, database_embeddings, database_info, top_k=5):
    """Rank database items by dot-product similarity to a query embedding.

    Args:
        query_embedding: NumPy array holding one query embedding, either
            1-D or with a leading axis of size 1.
        database_embeddings: 2-D array with one embedding per database item.
        database_info: Sequence of per-item records, aligned with the rows
            of ``database_embeddings``.
        top_k: Maximum number of results to return.

    Returns:
        A list of ``{'info': ..., 'similarity': ...}`` dicts for the ``top_k``
        highest-scoring items, best match first.
    """
    # reshape(-1) rather than squeeze(): with a one-item database squeeze()
    # yields a 0-d array, and indexing it below raises IndexError.
    similarities = np.dot(database_embeddings, query_embedding.T).reshape(-1)
    top_indices = np.argsort(similarities)[::-1][:top_k]

    return [
        {'info': database_info[idx], 'similarity': similarities[idx]}
        for idx in top_indices
    ]
166
+
167
  # Streamlit app
168
+ st.title("Fashion Search App with Segmentation and Category Detection")
169
 
170
  # API Key input
171
  api_key = st.text_input("Enter your Roboflow API Key", type="password")
 
183
 
184
  if st.button('Find Similar Items'):
185
  with st.spinner('Processing...'):
 
186
  temp_path = "temp_upload.jpg"
187
  image.save(temp_path)
188
 
189
+ segmented_image, input_categories = segment_image_and_get_categories(temp_path, CLIENT)
 
190
  st.image(segmented_image, caption='Segmented Image', use_column_width=True)
191
 
192
+ st.subheader("Detected Categories in Input Image:")
193
+ for category in input_categories:
194
+ st.write(category)
195
+
196
  query_embedding = get_image_embedding(segmented_image)
197
+ similar_images = find_similar_images(query_embedding, database_embeddings, database_info)
198
 
199
  st.subheader("Similar Items:")
200
  for img in similar_images:
 
208
  st.write(f"Price: {img['info']['price']}")
209
  st.write(f"Discount: {img['info']['discount']}%")
210
  st.write(f"Similarity: {img['similarity']:.2f}")
211
+
212
+ st.write("Detected Categories:")
213
+ for category in img['info']['categories']:
214
+ st.write(category)
215
  else:
216
  st.warning("Please enter your Roboflow API Key to use the app.")