Update app.py
Browse files
app.py
CHANGED
|
@@ -12,15 +12,6 @@ is_stopped = False
|
|
| 12 |
seed = random.randint(0,100000)
|
| 13 |
setup_seed(seed)
|
| 14 |
|
| 15 |
-
device = torch.device("cpu")
|
| 16 |
-
vocab_mlm = create_vocab()
|
| 17 |
-
vocab_mlm = add_tokens_to_vocab(vocab_mlm)
|
| 18 |
-
save_path = 'mlm-model-27.pt'
|
| 19 |
-
train_seqs = pd.read_csv('C0_seq.csv')
|
| 20 |
-
train_seq = train_seqs['Seq'].tolist()
|
| 21 |
-
model = torch.load(save_path, map_location=torch.device('cpu'))
|
| 22 |
-
model = model.to(device)
|
| 23 |
-
|
| 24 |
def temperature_sampling(logits, temperature):
|
| 25 |
logits = logits / temperature
|
| 26 |
probabilities = torch.softmax(logits, dim=-1)
|
|
@@ -32,11 +23,20 @@ def stop_generation():
|
|
| 32 |
is_stopped = True
|
| 33 |
return "Generation stopped."
|
| 34 |
|
| 35 |
-
def CTXGen(X1, X2, τ, g_num, length_range):
|
| 36 |
global is_stopped
|
| 37 |
is_stopped = False
|
| 38 |
start, end = length_range
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
msa_data = pd.read_csv('conoData_C0.csv')
|
| 41 |
msa = msa_data['Sequences'].tolist()
|
| 42 |
msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
|
|
@@ -158,6 +158,7 @@ with gr.Blocks() as demo:
|
|
| 158 |
τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
|
| 159 |
g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
|
| 160 |
length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
|
|
|
|
| 161 |
with gr.Row():
|
| 162 |
start_button = gr.Button("Start Generation")
|
| 163 |
stop_button = gr.Button("Stop Generation")
|
|
@@ -166,7 +167,7 @@ with gr.Blocks() as demo:
|
|
| 166 |
with gr.Row():
|
| 167 |
output_file = gr.File(label="Download generated conotoxins")
|
| 168 |
|
| 169 |
-
start_button.click(CTXGen, inputs=[X1, X2, τ, g_num, length_range], outputs=[output_df, output_file])
|
| 170 |
stop_button.click(stop_generation, outputs=None)
|
| 171 |
|
| 172 |
demo.launch()
|
|
|
|
| 12 |
seed = random.randint(0,100000)
|
| 13 |
setup_seed(seed)
|
| 14 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
def temperature_sampling(logits, temperature):
|
| 16 |
logits = logits / temperature
|
| 17 |
probabilities = torch.softmax(logits, dim=-1)
|
|
|
|
| 23 |
is_stopped = True
|
| 24 |
return "Generation stopped."
|
| 25 |
|
| 26 |
+
def CTXGen(X1, X2, τ, g_num, length_range, model_name):
|
| 27 |
global is_stopped
|
| 28 |
is_stopped = False
|
| 29 |
start, end = length_range
|
| 30 |
|
| 31 |
+
device = torch.device("cpu")
|
| 32 |
+
vocab_mlm = create_vocab()
|
| 33 |
+
vocab_mlm = add_tokens_to_vocab(vocab_mlm)
|
| 34 |
+
save_path = model_name
|
| 35 |
+
train_seqs = pd.read_csv('C0_seq.csv')
|
| 36 |
+
train_seq = train_seqs['Seq'].tolist()
|
| 37 |
+
model = torch.load(save_path, map_location=torch.device('cpu'))
|
| 38 |
+
model = model.to(device)
|
| 39 |
+
|
| 40 |
msa_data = pd.read_csv('conoData_C0.csv')
|
| 41 |
msa = msa_data['Sequences'].tolist()
|
| 42 |
msa = [x for x in msa if x.startswith(f"{X1}|{X2}")]
|
|
|
|
| 158 |
τ = gr.Slider(minimum=1, maximum=2, step=0.1, label="τ")
|
| 159 |
g_num = gr.Dropdown(choices=[1, 10, 20, 30, 40, 50], label="Number of generations")
|
| 160 |
length_range = RangeSlider(minimum=8, maximum=50, step=1, value=(12, 16), label="Length range")
|
| 161 |
+
model_name = gr.Dropdown(choices=['model_final.pt','model_C1.pt','model_C2.pt','model_C3.pt','model_C4.pt','model_C5.pt','model_mlm.pt'], label="Model")
|
| 162 |
with gr.Row():
|
| 163 |
start_button = gr.Button("Start Generation")
|
| 164 |
stop_button = gr.Button("Stop Generation")
|
|
|
|
| 167 |
with gr.Row():
|
| 168 |
output_file = gr.File(label="Download generated conotoxins")
|
| 169 |
|
| 170 |
+
start_button.click(CTXGen, inputs=[X1, X2, τ, g_num, length_range,model_name], outputs=[output_df, output_file])
|
| 171 |
stop_button.click(stop_generation, outputs=None)
|
| 172 |
|
| 173 |
demo.launch()
|