Spaces:
Runtime error
Runtime error
import os
from html import escape

import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import AutoTokenizer, RobertaModel
def _cache_resource(func):
    """Wrap *func* in Streamlit's resource cache when the installed version has it.

    Streamlit re-executes the whole script on every widget interaction, so
    without caching, the model, tokenizer and embeddings below would be
    reloaded on each query. ``st.cache_resource`` only exists in newer
    Streamlit releases; on older ones *func* is returned unchanged.
    """
    cache_resource = getattr(st, 'cache_resource', None)
    if cache_resource is None:
        return func
    return cache_resource(func)


@_cache_resource
def load():
    """Load the text encoder, its tokenizer, and the precomputed image index.

    Returns:
        tuple: ``(text_encoder, tokenizer, links, image_embeddings)``.
        ``links`` is the array stored in ``link.npy``. ``image_embeddings``
        is the object stored in ``embeddings.pt`` (presumably one row per
        entry of ``links`` -- TODO confirm).
    """
    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text')
    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
    # allow_pickle is required for object arrays; only ever load trusted files.
    links = np.load('link.npy', allow_pickle=True)
    # map_location='cpu' lets tensors saved on a GPU load on a CPU-only host.
    image_embeddings = torch.load('embeddings.pt', map_location='cpu')
    return text_encoder, tokenizer, links, image_embeddings


text_encoder, tokenizer, links, image_embeddings = load()
def get_html(url_list, height=224):
    """Render image URLs as a flex-wrapped HTML gallery.

    Args:
        url_list: Iterable of image URLs. Each one is HTML-escaped before
            being placed in an ``src`` attribute.
        height: Pixel height applied to every ``<img>`` element.

    Returns:
        str: A single ``<div>`` containing one ``<img>`` per URL, in order.
    """
    # join instead of repeated += avoids quadratic string building.
    images = ''.join(
        f"<img style='height: {height}px; margin: 5px' src='{escape(url)}'>"
        for url in url_list
    )
    return (
        "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
        + images
        + "</div>"
    )
def image_search(query, top_k=8):
    """Return the links of the images most similar to a text query.

    Args:
        query: Free-text search string.
        top_k: Maximum number of links to return.

    Returns:
        list: Up to ``top_k`` entries of the module-level ``links``, ordered
        from most to least cosine-similar to the query's text embedding.
    """
    with torch.no_grad():
        # truncation=True keeps over-long queries within the encoder's maximum
        # sequence length instead of raising inside the model.
        inputs = tokenizer(query, return_tensors='pt', truncation=True)
        text_embedding = text_encoder(**inputs).pooler_output
        # Single-query batch; assumes image_embeddings is (num_images, dim) -- TODO confirm.
        similarities = torch.cosine_similarity(text_embedding, image_embeddings)
        order = similarities.argsort(descending=True)
    # Convert to plain ints so indexing `links` does not depend on tensor __index__.
    return [links[i] for i in order[:top_k].tolist()]
# Markdown text rendered in the sidebar by main().
description = "\n# Semantic image search :)\n"
# Page-level CSS: widens the main container, lays radio options out in a row,
# adjusts section padding, and hides Streamlit's main menu and footer.
_PAGE_CSS = '''
<style>
.block-container{
max-width: 1200px;
}
div.row-widget.stRadio > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>'''


def main():
    """Render the search page: styling, sidebar text, query box, and results."""
    st.markdown(_PAGE_CSS, unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # A 1:3:1 column split centres the query box; only the middle column is used.
    _, centre, _ = st.columns((1, 3, 1))
    query = centre.text_input('', value='clouds at sunset')

    # Nothing to search for on an empty query.
    if not query:
        return

    gallery = get_html(image_search(query))
    st.markdown(gallery, unsafe_allow_html=True)


if __name__ == '__main__':
    main()