This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_response_sql(user_query, chat_history, plot=False): | |
# Specify the path to the SQLite database | |
db_path = "metadataDB/output_database.db" | |
# Connect to the SQLite database | |
db = SQLDatabase.from_uri(f"sqlite:///{db_path}") | |
underspecified = classify_underspecified_query(user_query , chat_history) | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def get_response_sql(user_query, chat_history, plot=False): | |
# Specify the path to the SQLite database | |
db_path = "metadataDB/output_database.db" | |
# Connect to the SQLite database | |
db = SQLDatabase.from_uri(f"sqlite:///{db_path}") | |
underspecified = classify_underspecified_query(user_query , chat_history) | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
from typing import Iterable | |
from collections import Counter | |
import os | |
import logging | |
import sys | |
import json | |
import click | |
import datasets | |
import numpy as np |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import argparse | |
from typing import Optional, Union, Tuple | |
import torch | |
torch.manual_seed(0) | |
from transformers import BertModel, BertTokenizer, PreTrainedModel, BertConfig | |
from transformers.modeling_outputs import MultipleChoiceModelOutput |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import os | |
all_types_to_idx = { | |
'Task': 0, | |
'Method': 1, | |
'Material': 2, | |
'Metric': 3, | |
'OtherScientificTerm': 4, | |
'Generic': 5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
from collections import Counter | |
from urllib.parse import urlparse | |
import json | |
import os | |
import re | |
from tqdm import tqdm | |
urls_counts = {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import List | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import seaborn as sns | |
import scipy | |
import random | |
class NormalGammaPrior(): | |
"""" | |
Suppose X is distributed according to a normal distribution: X ~ N(mu, tau^{-1}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import statistics as st | |
import scipy.stats | |
import numpy as np | |
def metric1(scores, row_aggregator, column_aggregator, cell_aggregator): | |
row_values = [] | |
for row_idx, row1 in enumerate(scores): | |
diagonal_x = row1[row_idx] | |
row_values.append( |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
show_unpublished_scores: true | |
datasets: | |
blind_labels: danielk/genie_labels | |
evaluator: | |
image: jbragg/genie-evaluator | |
input_path: /preds/ | |
predictions_filename: predictions.json | |
label_path: /labels/ | |
output_path: /results | |
arguments: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import statistics as st | |
def metric1(scores, row_aggregator, column_aggregator, cell_aggregator): | |
row_values = [] | |
for row_idx, row1 in enumerate(scores): | |
diagonal_x = row1[row_idx] | |
row_values.append( | |
column_aggregator( | |
[cell_aggregator(diagonal_x, x, abs(col_idx - row_idx)) for col_idx, x in enumerate(row1) if col_idx != row_idx] | |
) |
NewerOlder