# Source: GitHub Gist by @DDoSolitary (gist 3daacb30015f7fd14c0e4fa4ca751bbe),
# created June 11, 2021 10:02. (GitHub page navigation text removed.)
#!/usr/bin/env python3
from __future__ import annotations
import functools
import itertools
import json
import multiprocessing
import random
import re
import string
import subprocess
import tempfile
import networkx as nx
from abc import ABC
from argparse import ArgumentParser
from collections import defaultdict, deque
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union
# Number of requests emitted for every randomly generated test case.
REQUEST_COUNT = 1000

# Random person names and notice texts: fixed length, alphanumeric characters.
NAME_LEN = 10
NAME_ALPHABET = string.ascii_letters + string.digits
NOTICE_LEN = 100
NOTICE_ALPHABET = NAME_ALPHABET

# Every MIN_*/MAX_* pair below is handed to random.randrange(MIN, MAX) by the
# generators in gen_test_case, so each MAX_* value is an exclusive upper bound.
MAX_ID = 2 ** 31
MIN_ID = -MAX_ID

MIN_AGE = 0
MAX_AGE = 200 + 1

MIN_VALUE = 0
MAX_VALUE = 1000 + 1

MIN_SOCIAL_VALUE = -1000
MAX_SOCIAL_VALUE = 1000 + 1

MIN_EMOJI_ID = 0
MAX_EMOJI_ID = 10000 + 1

MIN_MONEY = 0
MAX_MONEY = 200 + 1

MIN_HEAT_LIMIT = 0
MAX_HEAT_LIMIT = 10
class InputRequest(ABC):
    """Base class of every request line that is fed to the tested programs.

    ``op`` holds the textual opcode (first whitespace-separated field of the
    line); subclasses add the operands and render the full line in ``__str__``.
    """

    op: str

    def __init__(self, op: str) -> None:
        self.op = op

    @staticmethod
    def parse(s: str) -> InputRequest:
        """Build the request object described by one input line.

        Args:
            s: A request line such as ``'ap 1 alice 20'``; fields are split on
                arbitrary whitespace and surplus trailing fields are ignored.

        Returns:
            An instance of the request class matching the opcode.

        Raises:
            ValueError: Unknown opcode (``'invalid op'``) or a non-integer
                operand where an integer is required.
            IndexError: The line is empty or has too few operands.
        """
        fields = s.split()
        op = fields[0]

        def ints(count: int) -> list[int]:
            # Convert operands 1..count to int, left to right.
            return [int(fields[i]) for i in range(1, count + 1)]

        # Builders are lazy so that only the selected opcode touches `fields`.
        builders = {
            'ap': lambda: AddPersonRequest(int(fields[1]), fields[2], int(fields[3])),
            'ar': lambda: AddRelationRequest(*ints(3)),
            'qv': lambda: QueryValueRequest(*ints(2)),
            'cn': lambda: CompareNameRequest(*ints(2)),
            'qnr': lambda: QueryNameRankRequest(*ints(1)),
            'qps': lambda: QueryPeopleSumRequest(),
            'qci': lambda: QueryCircleRequest(*ints(2)),
            'qbs': lambda: QueryBlockSumRequest(),
            'ag': lambda: AddGroupRequest(*ints(1)),
            'atg': lambda: AddToGroupRequest(*ints(2)),
            'qgs': lambda: QueryGroupSumRequest(),
            'qgps': lambda: QueryGroupPeopleSumRequest(*ints(1)),
            'qgvs': lambda: QueryGroupValueSumRequest(*ints(1)),
            'qgam': lambda: QueryGroupAgeMeanRequest(*ints(1)),
            'qgav': lambda: QueryGroupAgeVarRequest(*ints(1)),
            'dfg': lambda: DelFromGroupRequest(*ints(2)),
            'am': lambda: AddMessageRequest(*ints(5)),
            'sm': lambda: SendMessageRequest(*ints(1)),
            'qsv': lambda: QuerySocialValueRequest(*ints(1)),
            'qrm': lambda: QueryReceivedMessagesRequest(*ints(1)),
            'arem': lambda: AddRedEnvelopeMessageRequest(*ints(5)),
            'anm': lambda: AddNoticeMessageRequest(
                int(fields[1]), fields[2], int(fields[3]), int(fields[4]), int(fields[5])),
            'aem': lambda: AddEmojiMessageRequest(*ints(5)),
            'sei': lambda: StoreEmojiIdRequest(*ints(1)),
            'qp': lambda: QueryPopularityRequest(*ints(1)),
            'dce': lambda: DeleteColdEmojiRequest(*ints(1)),
            'qm': lambda: QueryMoneyRequest(*ints(1)),
            'sim': lambda: SendIndirectMessageRequest(*ints(1)),
        }
        build = builders.get(op)
        if build is None:
            raise ValueError('invalid op')
        return build()
class AddPersonRequest(InputRequest):
    """``ap <id> <name> <age>``: add a person to the network."""

    id: int
    name: str
    age: int

    def __init__(self, _id: int, name: str, age: int) -> None:
        super().__init__('ap')
        self.id, self.name, self.age = _id, name, age

    def __str__(self) -> str:
        # Rendered exactly as the input line expected by the tested program.
        return f'{self.op} {self.id} {self.name} {self.age}'
class AddRelationRequest(InputRequest):
    """``ar <id1> <id2> <value>``: link two people with a weighted relation."""

    id1: int
    id2: int
    value: int

    def __init__(self, id1: int, id2: int, value: int) -> None:
        super().__init__('ar')
        self.id1, self.id2, self.value = id1, id2, value

    def __str__(self) -> str:
        return f'{self.op} {self.id1} {self.id2} {self.value}'
class QueryValueRequest(InputRequest):
    """``qv <id1> <id2>``: query the value of the relation between two people."""

    id1: int
    id2: int

    def __init__(self, id1: int, id2: int) -> None:
        super().__init__('qv')
        self.id1, self.id2 = id1, id2

    def __str__(self) -> str:
        return f'{self.op} {self.id1} {self.id2}'
class CompareNameRequest(InputRequest):
    """``cn <id1> <id2>``: compare the names of two people."""

    id1: int
    id2: int

    def __init__(self, id1: int, id2: int) -> None:
        super().__init__('cn')
        self.id1, self.id2 = id1, id2

    def __str__(self) -> str:
        return f'{self.op} {self.id1} {self.id2}'
class QueryNameRankRequest(InputRequest):
    """``qnr <id>``: query the rank of a person's name among all names."""

    id: int

    def __init__(self, _id: int) -> None:
        super().__init__('qnr')
        self.id = _id

    def __str__(self) -> str:
        return f'{self.op} {self.id}'
class QueryPeopleSumRequest(InputRequest):
    """``qps``: query how many people the network contains."""

    def __init__(self) -> None:
        super().__init__('qps')

    def __str__(self) -> str:
        # No operands: the line is just the opcode.
        return self.op
class QueryCircleRequest(InputRequest):
    """``qci <id1> <id2>``: query whether two people are connected by a path."""

    id1: int
    id2: int

    def __init__(self, id1: int, id2: int) -> None:
        super().__init__('qci')
        self.id1, self.id2 = id1, id2

    def __str__(self) -> str:
        return f'{self.op} {self.id1} {self.id2}'
class QueryBlockSumRequest(InputRequest):
    """``qbs``: query the number of connected components of the network."""

    def __init__(self) -> None:
        super().__init__('qbs')

    def __str__(self) -> str:
        # No operands: the line is just the opcode.
        return self.op
class AddGroupRequest(InputRequest):
    """``ag <id>``: create an empty group."""

    id: int

    def __init__(self, _id: int) -> None:
        super().__init__('ag')
        self.id = _id

    def __str__(self):
        return f'{self.op} {self.id}'
class AddToGroupRequest(InputRequest):
    """``atg <id1> <id2>``: add person ``id1`` to group ``id2``."""

    id1: int
    id2: int

    def __init__(self, id1: int, id2: int) -> None:
        super().__init__('atg')
        self.id1, self.id2 = id1, id2

    def __str__(self):
        return f'{self.op} {self.id1} {self.id2}'
class QueryGroupSumRequest(InputRequest):
    """``qgs``: query how many groups exist."""

    def __init__(self) -> None:
        super().__init__('qgs')

    def __str__(self) -> str:
        # No operands: the line is just the opcode.
        return self.op
class QueryGroupPeopleSumRequest(InputRequest):
    """``qgps <id>``: query the number of members of a group."""

    id: int

    def __init__(self, _id: int) -> None:
        super().__init__('qgps')
        self.id = _id

    def __str__(self):
        return f'{self.op} {self.id}'
class QueryGroupValueSumRequest(InputRequest):
    """``qgvs <id>``: query the summed relation value inside a group."""

    id: int

    def __init__(self, _id: int) -> None:
        super().__init__('qgvs')
        self.id = _id

    def __str__(self):
        return f'{self.op} {self.id}'
class QueryGroupAgeMeanRequest(InputRequest):
    """``qgam <id>``: query the (integer) mean age of a group's members."""

    id: int

    def __init__(self, _id: int) -> None:
        super().__init__('qgam')
        self.id = _id

    def __str__(self):
        return f'{self.op} {self.id}'
class QueryGroupAgeVarRequest(InputRequest):
    """``qgav <id>``: query the (integer) age variance of a group's members."""

    id: int

    def __init__(self, _id: int) -> None:
        super().__init__('qgav')
        self.id = _id

    def __str__(self):
        return f'{self.op} {self.id}'
class DelFromGroupRequest(InputRequest):
    """``dfg <id1> <id2>``: remove person ``id1`` from group ``id2``."""

    id1: int
    id2: int

    def __init__(self, id1: int, id2: int) -> None:
        super().__init__('dfg')
        self.id1, self.id2 = id1, id2

    def __str__(self):
        return f'{self.op} {self.id1} {self.id2}'
class AddMessageRequest(InputRequest):
    """``am <id> <social_value> <type> <id1> <id2>``: queue an ordinary message.

    ``type`` 0 addresses person ``id2``; ``type`` 1 addresses group ``id2``.
    ``id1`` is the sender.
    """

    id: int
    social_value: int
    type: int
    id1: int
    id2: int

    def __init__(self, _id: int, social_value: int, _type: int, id1: int, id2: int) -> None:
        super().__init__('am')
        self.id, self.social_value, self.type = _id, social_value, _type
        self.id1, self.id2 = id1, id2

    def __str__(self):
        return f'{self.op} {self.id} {self.social_value} {self.type} {self.id1} {self.id2}'
class SendMessageRequest(InputRequest):
    """``sm <id>``: deliver a previously queued message."""

    id: int

    def __init__(self, _id: int) -> None:
        super().__init__('sm')
        self.id = _id

    def __str__(self):
        return f'{self.op} {self.id}'
class QuerySocialValueRequest(InputRequest):
    """``qsv <id>``: query a person's accumulated social value."""

    id: int

    def __init__(self, _id: int) -> None:
        super().__init__('qsv')
        self.id = _id

    def __str__(self):
        return f'{self.op} {self.id}'
class QueryReceivedMessagesRequest(InputRequest):
    """``qrm <id>``: query the most recent messages a person received."""

    id: int

    def __init__(self, _id: int) -> None:
        super().__init__('qrm')
        self.id = _id

    def __str__(self):
        return f'{self.op} {self.id}'
class AddRedEnvelopeMessageRequest(InputRequest):
    """``arem <id> <money> <type> <id1> <id2>``: queue a red-envelope message.

    ``type`` 0 addresses person ``id2``; ``type`` 1 addresses group ``id2``.
    ``id1`` is the sender.
    """

    id: int
    money: int
    type: int
    id1: int
    id2: int

    def __init__(self, _id: int, money: int, _type: int, id1: int, id2: int) -> None:
        super().__init__('arem')
        self.id, self.money, self.type = _id, money, _type
        self.id1, self.id2 = id1, id2

    def __str__(self):
        return f'{self.op} {self.id} {self.money} {self.type} {self.id1} {self.id2}'
class AddNoticeMessageRequest(InputRequest):
    """``anm <id> <string> <type> <id1> <id2>``: queue a notice message.

    ``type`` 0 addresses person ``id2``; ``type`` 1 addresses group ``id2``.
    ``id1`` is the sender and ``string`` the notice text.
    """

    id: int
    string: str
    type: int
    id1: int
    id2: int

    def __init__(self, _id: int, _string: str, _type: int, id1: int, id2: int) -> None:
        super().__init__('anm')
        self.id, self.string, self.type = _id, _string, _type
        self.id1, self.id2 = id1, id2

    def __str__(self):
        return f'{self.op} {self.id} {self.string} {self.type} {self.id1} {self.id2}'
class AddEmojiMessageRequest(InputRequest):
    """``aem <id> <emoji_id> <type> <id1> <id2>``: queue an emoji message.

    ``type`` 0 addresses person ``id2``; ``type`` 1 addresses group ``id2``.
    ``id1`` is the sender.
    """

    id: int
    emoji_id: int
    type: int
    id1: int
    id2: int

    def __init__(self, _id: int, emoji_id: int, _type: int, id1: int, id2: int) -> None:
        super().__init__('aem')
        self.id, self.emoji_id, self.type = _id, emoji_id, _type
        self.id1, self.id2 = id1, id2

    def __str__(self):
        return f'{self.op} {self.id} {self.emoji_id} {self.type} {self.id1} {self.id2}'
class StoreEmojiIdRequest(InputRequest):
    """``sei <id>``: register an emoji id."""

    id: int

    def __init__(self, _id: int) -> None:
        super().__init__('sei')
        self.id = _id

    def __str__(self):
        return f'{self.op} {self.id}'
class QueryPopularityRequest(InputRequest):
    """``qp <id>``: query how often an emoji has been sent."""

    id: int

    def __init__(self, _id: int) -> None:
        super().__init__('qp')
        self.id = _id

    def __str__(self):
        return f'{self.op} {self.id}'
class DeleteColdEmojiRequest(InputRequest):
    """``dce <limit>``: delete emojis whose send count is below ``limit``."""

    limit: int

    def __init__(self, limit: int) -> None:
        super().__init__('dce')
        self.limit = limit

    def __str__(self):
        return f'{self.op} {self.limit}'
class QueryMoneyRequest(InputRequest):
    """``qm <id>``: query a person's money balance."""

    id: int

    def __init__(self, _id: int) -> None:
        super().__init__('qm')
        self.id = _id

    def __str__(self):
        return f'{self.op} {self.id}'
class SendIndirectMessageRequest(InputRequest):
    """``sim <id>``: deliver a queued person-to-person message along a path."""

    id: int

    def __init__(self, _id: int) -> None:
        super().__init__('sim')
        self.id = _id

    def __str__(self):
        return f'{self.op} {self.id}'
@dataclass
class Person:
    """A person in the reference network.

    Identity is defined by ``id`` alone: two ``Person`` objects with the same
    id compare equal and hash equally, whatever their other fields hold.
    Instances are used as set members and as graph nodes, so the hash must be
    stable and well distributed.
    """

    id: int
    name: str
    age: int
    social_value: int
    money: int
    # Received messages, most recent first (new ones are pushed on the left).
    messages: deque[Message]

    def __hash__(self) -> int:
        # Must hash the instance's id. Hashing the builtin ``id`` function
        # would give every person the same hash value and turn every set,
        # dict and graph lookup into a linear collision chain.
        return hash(self.id)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Person):
            return NotImplemented
        return self.id == other.id
@dataclass
class Group:
    """A named set of people that can be the destination of a message."""

    id: int
    # Current members; mutated in place by the add/remove group requests.
    members: set[Person]
@dataclass
class Message:
    """An ordinary message waiting to be sent, or already delivered.

    ``str()`` of a message is the text shown for it in a received-messages
    query reply.
    """

    id: int
    social_value: int
    src: Person
    # Either a single recipient or a whole group.
    dst: Union[Person, Group]

    def __str__(self) -> str:
        return 'Ordinary message'
@dataclass
class NoticeMessage(Message):
    """Message that carries a free-text notice."""

    notice: str

    def __str__(self) -> str:
        return f'notice: {self.notice}'
@dataclass
class EmojiMessage(Message):
    """Message that carries an emoji, identified by its emoji id."""

    emoji_id: int

    def __str__(self) -> str:
        return f'Emoji: {self.emoji_id}'
@dataclass
class RedEnvelopeMessage(Message):
    """Message that transfers money from the sender to the recipient(s)."""

    money: int

    def __str__(self) -> str:
        return f'RedEnvelope: {self.money}'
class ExceptionCounter:
    """Formats error replies and tracks how often each error kind occurred.

    A counter keeps one global tally for its ``tag`` and one tally per
    offending id; both appear in the reply text.
    """

    tag: str
    global_counter: int
    person_counters: defaultdict[int, int]

    def __init__(self, tag: str) -> None:
        self.tag = tag
        self.global_counter = 0
        self.person_counters = defaultdict(int)

    def count(self, _id: int) -> str:
        """Record an error caused by ``_id`` and return its reply line."""
        self.global_counter += 1
        self.person_counters[_id] += 1
        per_id = self.person_counters[_id]
        return f'{self.tag}-{self.global_counter}, {_id}-{per_id}'

    def count_both(self, id1: int, id2: int) -> str:
        """Record an error caused by a pair of ids and return its reply line.

        The smaller id is always reported first. When both ids are equal the
        per-id tally is bumped only once but still printed twice.
        """
        self.global_counter += 1
        low, high = (id2, id1) if id2 < id1 else (id1, id2)
        self.person_counters[low] += 1
        if low != high:
            self.person_counters[high] += 1
        return f'{self.tag}-{self.global_counter}, {low}-{self.person_counters[low]}, {high}-{self.person_counters[high]}'
class PeopleNetwork:
    """Reference model that computes the expected reply for every request.

    People are nodes of an undirected ``networkx`` graph whose edges carry a
    ``value`` attribute. Each person receives a zero-valued self-loop when
    added, so a person always counts as linked to themselves and isolated
    people still form their own connected component. Error replies are
    produced by one :class:`ExceptionCounter` per error kind.
    """

    people: dict[int, Person]
    groups: dict[int, Group]
    messages: dict[int, Message]  # queued (not yet sent) messages by id
    emojis: dict[int, int]  # emoji id -> number of times it has been sent
    graph: nx.Graph
    pinfCounter: ExceptionCounter
    epiCounter: ExceptionCounter
    rnfCounter: ExceptionCounter
    erCounter: ExceptionCounter
    ginfCounter: ExceptionCounter
    egiCounter: ExceptionCounter
    minfCounter: ExceptionCounter
    emiCounter: ExceptionCounter
    einfCounter: ExceptionCounter
    eeiCounter: ExceptionCounter

    def __init__(self):
        self.people = dict()
        self.groups = dict()
        self.messages = dict()
        self.emojis = dict()
        self.graph = nx.Graph()
        self.pinfCounter = ExceptionCounter('pinf')
        self.epiCounter = ExceptionCounter('epi')
        self.rnfCounter = ExceptionCounter('rnf')
        self.erCounter = ExceptionCounter('er')
        self.ginfCounter = ExceptionCounter('ginf')
        self.egiCounter = ExceptionCounter('egi')
        self.minfCounter = ExceptionCounter('minf')
        self.emiCounter = ExceptionCounter('emi')
        self.einfCounter = ExceptionCounter('einf')
        self.eeiCounter = ExceptionCounter('eei')

    def process_request(self, req: InputRequest) -> str:
        """Apply ``req`` to the model and return the expected output line.

        Args:
            req: Any concrete request object.

        Returns:
            The reply text: ``'Ok'``, a query result, or an error line built
            by the matching exception counter.

        Raises:
            ValueError: ``req`` is not one of the known request types.
        """
        if isinstance(req, AddPersonRequest):
            if req.id in self.people:
                return self.epiCounter.count(req.id)
            p = Person(req.id, req.name, req.age, 0, 0, deque())
            self.people[req.id] = p
            # The self-loop registers the node and makes p linked to itself.
            self.graph.add_edge(p, p, value=0)
            return 'Ok'
        elif isinstance(req, AddRelationRequest):
            p1 = self.people.get(req.id1)
            if p1 is None:
                return self.pinfCounter.count(req.id1)
            p2 = self.people.get(req.id2)
            if p2 is None:
                return self.pinfCounter.count(req.id2)
            if self.graph.has_edge(p1, p2):
                return self.erCounter.count_both(req.id1, req.id2)
            self.graph.add_edge(p1, p2, value=req.value)
            return 'Ok'
        elif isinstance(req, QueryValueRequest):
            p1 = self.people.get(req.id1)
            if p1 is None:
                return self.pinfCounter.count(req.id1)
            p2 = self.people.get(req.id2)
            if p2 is None:
                return self.pinfCounter.count(req.id2)
            attrs = self.graph.get_edge_data(p1, p2)
            if attrs is None:
                return self.rnfCounter.count_both(req.id1, req.id2)
            return str(attrs['value'])
        elif isinstance(req, CompareNameRequest):
            p1 = self.people.get(req.id1)
            if p1 is None:
                return self.pinfCounter.count(req.id1)
            p2 = self.people.get(req.id2)
            if p2 is None:
                return self.pinfCounter.count(req.id2)
            if p1.name < p2.name:
                return '<'
            if p1.name == p2.name:
                return '='
            return '>'
        elif isinstance(req, QueryNameRankRequest):
            p = self.people.get(req.id)
            if p is None:
                return self.pinfCounter.count(req.id)
            # Rank is 1 + the number of people with a strictly smaller name.
            return str(sum(1 for o in self.people.values() if o.name < p.name) + 1)
        elif isinstance(req, QueryPeopleSumRequest):
            return str(len(self.people))
        elif isinstance(req, QueryCircleRequest):
            p1 = self.people.get(req.id1)
            if p1 is None:
                return self.pinfCounter.count(req.id1)
            p2 = self.people.get(req.id2)
            if p2 is None:
                return self.pinfCounter.count(req.id2)
            return '1' if nx.has_path(self.graph, p1, p2) else '0'
        elif isinstance(req, QueryBlockSumRequest):
            return str(nx.number_connected_components(self.graph))
        elif isinstance(req, AddGroupRequest):
            if req.id in self.groups:
                return self.egiCounter.count(req.id)
            self.groups[req.id] = Group(req.id, set())
            return 'Ok'
        elif isinstance(req, AddToGroupRequest):
            g = self.groups.get(req.id2)
            if g is None:
                return self.ginfCounter.count(req.id2)
            p = self.people.get(req.id1)
            if p is None:
                return self.pinfCounter.count(req.id1)
            if p in g.members:
                return self.epiCounter.count(req.id1)
            g.members.add(p)
            return 'Ok'
        elif isinstance(req, QueryGroupSumRequest):
            return str(len(self.groups))
        elif isinstance(req, QueryGroupPeopleSumRequest):
            g = self.groups.get(req.id)
            if g is None:
                return self.ginfCounter.count(req.id)
            return str(len(g.members))
        elif isinstance(req, QueryGroupValueSumRequest):
            g = self.groups.get(req.id)
            if g is None:
                return self.ginfCounter.count(req.id)
            # Each undirected edge is listed once, hence the factor 2; the
            # zero-valued self-loops do not contribute.
            return str(2 * sum(v for _, _, v in self.graph.subgraph(g.members).edges.data('value')))
        elif isinstance(req, QueryGroupAgeMeanRequest):
            g = self.groups.get(req.id)
            if g is None:
                return self.ginfCounter.count(req.id)
            if len(g.members) == 0:
                return '0'
            return str(sum(p.age for p in g.members) // len(g.members))
        elif isinstance(req, QueryGroupAgeVarRequest):
            g = self.groups.get(req.id)
            if g is None:
                return self.ginfCounter.count(req.id)
            if len(g.members) == 0:
                return '0'
            # Integer arithmetic throughout: floor mean, then floor variance.
            mean = sum(p.age for p in g.members) // len(g.members)
            return str(sum((p.age - mean) ** 2 for p in g.members) // len(g.members))
        elif isinstance(req, DelFromGroupRequest):
            g = self.groups.get(req.id2)
            if g is None:
                return self.ginfCounter.count(req.id2)
            p = self.people.get(req.id1)
            if p is None:
                return self.pinfCounter.count(req.id1)
            if p not in g.members:
                return self.epiCounter.count(req.id1)
            g.members.remove(p)
            return 'Ok'
        elif isinstance(req, (AddMessageRequest, AddRedEnvelopeMessageRequest,
                              AddNoticeMessageRequest, AddEmojiMessageRequest)):
            # These two fixed replies are not counted by any exception counter.
            if req.type == 1 and req.id2 not in self.groups:
                return 'Group does not exist'
            if req.id1 not in self.people or (req.type == 0 and req.id2 not in self.people):
                return 'The person with this number does not exist'
            if req.id in self.messages:
                return self.emiCounter.count(req.id)
            src = self.people[req.id1]
            dst = self.people[req.id2] if req.type == 0 else self.groups[req.id2]
            msg: Message
            if isinstance(req, AddMessageRequest):
                msg = Message(req.id, req.social_value, src, dst)
            elif isinstance(req, AddRedEnvelopeMessageRequest):
                msg = RedEnvelopeMessage(req.id, req.money * 5, src, dst, req.money)
            elif isinstance(req, AddNoticeMessageRequest):
                msg = NoticeMessage(req.id, len(req.string), src, dst, req.string)
            elif isinstance(req, AddEmojiMessageRequest):
                if req.emoji_id not in self.emojis:
                    return self.einfCounter.count(req.emoji_id)
                msg = EmojiMessage(req.id, req.emoji_id, src, dst, req.emoji_id)
            else:
                raise Exception('unreachable code')
            # Checked last: the unknown-emoji error takes precedence over it.
            if isinstance(dst, Person) and src == dst:
                return self.epiCounter.count(src.id)
            self.messages[msg.id] = msg
            return 'Ok'
        elif isinstance(req, SendMessageRequest):
            m = self.messages.get(req.id)
            if m is None:
                return self.minfCounter.count(req.id)
            if isinstance(m.dst, Person):
                if not self.graph.has_edge(m.src, m.dst):
                    return self.rnfCounter.count_both(m.src.id, m.dst.id)
                m.src.social_value += m.social_value
                m.dst.social_value += m.social_value
                if isinstance(m, RedEnvelopeMessage):
                    m.src.money -= m.money
                    m.dst.money += m.money
                m.dst.messages.appendleft(m)
            else:
                if m.src not in m.dst.members:
                    return self.pinfCounter.count(m.src.id)
                for gp in m.dst.members:
                    gp.social_value += m.social_value
                if isinstance(m, RedEnvelopeMessage):
                    # Equal integer share per member; the sender pays the
                    # shares of everybody else.
                    money = m.money // len(m.dst.members)
                    m.src.money -= money * (len(m.dst.members) - 1)
                    for gp in m.dst.members:
                        if gp != m.src:
                            gp.money += money
            if isinstance(m, EmojiMessage):
                self.emojis[m.emoji_id] += 1
            self.messages.pop(m.id)
            return 'Ok'
        elif isinstance(req, QuerySocialValueRequest):
            p = self.people.get(req.id)
            if p is None:
                return self.pinfCounter.count(req.id)
            return str(p.social_value)
        elif isinstance(req, QueryReceivedMessagesRequest):
            p = self.people.get(req.id)
            if p is None:
                return self.pinfCounter.count(req.id)
            if len(p.messages) == 0:
                return 'None'
            # Only the four most recent messages are reported.
            return '; '.join(str(msg) for msg in itertools.islice(p.messages, 0, 4))
        elif isinstance(req, StoreEmojiIdRequest):
            if req.id in self.emojis:
                return self.eeiCounter.count(req.id)
            self.emojis[req.id] = 0
            return 'Ok'
        elif isinstance(req, QueryPopularityRequest):
            if req.id not in self.emojis:
                return self.einfCounter.count(req.id)
            return str(self.emojis[req.id])
        elif isinstance(req, DeleteColdEmojiRequest):
            self.emojis = {k: v for k, v in self.emojis.items() if v >= req.limit}
            # Drop only the queued emoji messages whose emoji was just
            # deleted; every non-emoji message must stay queued.
            self.messages = {
                k: v for k, v in self.messages.items()
                if not isinstance(v, EmojiMessage) or v.emoji_id in self.emojis}
            return str(len(self.emojis))
        elif isinstance(req, QueryMoneyRequest):
            p = self.people.get(req.id)
            if p is None:
                return self.pinfCounter.count(req.id)
            return str(p.money)
        elif isinstance(req, SendIndirectMessageRequest):
            m = self.messages.get(req.id)
            # Group messages cannot be sent indirectly; treated as not found.
            if m is None or not isinstance(m.dst, Person):
                return self.minfCounter.count(req.id)
            try:
                # Weighted shortest path using the 'value' edge attribute.
                dis = nx.shortest_path_length(self.graph, m.src, m.dst, 'value')
            except nx.NetworkXNoPath:
                return '-1'
            m.src.social_value += m.social_value
            m.dst.social_value += m.social_value
            if isinstance(m, RedEnvelopeMessage):
                m.src.money -= m.money
                m.dst.money += m.money
            elif isinstance(m, EmojiMessage):
                self.emojis[m.emoji_id] += 1
            m.dst.messages.appendleft(m)
            self.messages.pop(m.id)
            return str(dis)
        else:
            raise ValueError('invalid request')
@dataclass
class TestError(Exception):
    """Failure of one tested program on one test case.

    Raised and caught inside ``do_test``; its fields are copied into the
    error report.
    """

    reason: str  # e.g. 'Wrong Answer', 'Runtime Error', 'Time Limit Exceeded'
    input: str  # the full request text fed to the program
    ans: list[str]  # expected output lines
    proc: Optional[subprocess.CompletedProcess]  # None when the run timed out
    err_line: Optional[int]  # 1-based first mismatching line, if any
def gen_test_case() -> tuple[list[InputRequest], list[str]]:
    """Generate one random test case together with its expected output.

    ``REQUEST_COUNT`` requests are drawn uniformly from all request kinds.
    Each request is applied to a fresh :class:`PeopleNetwork` as soon as it is
    generated, so later requests can deliberately reference ids that already
    exist. The id helpers mix "fresh" and "already used" ids (mostly with a
    1-in-10 chance for the unusual choice) so that both success and error
    paths get exercised.

    Returns:
        ``(requests, responses)`` where ``responses[i]`` is the reference
        reply to ``requests[i]``.
    """
    requests: list[InputRequest] = []
    responses: list[str] = []
    network = PeopleNetwork()

    def gen_id() -> int:
        return random.randrange(MIN_ID, MAX_ID)

    def gen_used_person_id() -> int:
        return random.choice(tuple(network.people.keys()))

    def gen_new_person_id() -> int:
        # Occasionally reuse an existing id to provoke a duplicate error.
        if random.randrange(10) == 0 and len(network.people) > 0:
            return gen_used_person_id()
        return gen_id()

    def gen_person_id() -> int:
        # Usually an existing person; occasionally a random (likely unknown) id.
        if random.randrange(10) == 0 or len(network.people) == 0:
            return gen_id()
        return gen_used_person_id()

    def gen_person_id_pair() -> tuple[int, int]:
        rand = random.randrange(10)
        if rand == 0 or len(network.people) == 0:
            # Both ids random; sometimes identical.
            if random.randrange(5) == 0:
                ret = gen_id()
                return ret, ret
            return gen_id(), gen_id()
        if rand == 1:
            # One existing and one random id, in random order.
            id1 = gen_used_person_id()
            id2 = gen_id()
            if random.randrange(0, 2) == 0:
                id1, id2 = id2, id1
            return id1, id2
        # Both ids existing; sometimes identical.
        if random.randrange(5) == 0:
            ret = gen_used_person_id()
            return ret, ret
        return gen_used_person_id(), gen_used_person_id()

    def gen_used_group_id() -> int:
        return random.choice(tuple(network.groups.keys()))

    def gen_new_group_id() -> int:
        if random.randrange(10) == 0 and len(network.groups) > 0:
            return gen_used_group_id()
        return gen_id()

    def gen_group_id() -> int:
        if random.randrange(10) == 0 or len(network.groups) == 0:
            return gen_id()
        return gen_used_group_id()

    def gen_used_msg_id() -> int:
        return random.choice(tuple(network.messages.keys()))

    def gen_new_msg_id() -> int:
        if random.randrange(10) == 0 and len(network.messages) > 0:
            return gen_used_msg_id()
        return gen_id()

    def gen_msg_id() -> int:
        if random.randrange(10) == 0 or len(network.messages) == 0:
            return gen_id()
        return gen_used_msg_id()

    def gen_name() -> str:
        # Occasionally duplicate an existing name so name comparisons can tie.
        if random.randrange(10) == 0 and len(network.people) > 0:
            return random.choice(tuple(p.name for p in network.people.values()))
        ret = ''.join(random.choice(NAME_ALPHABET) for _ in range(NAME_LEN))
        return ret

    def gen_age() -> int:
        return random.randrange(MIN_AGE, MAX_AGE)

    def gen_value() -> int:
        return random.randrange(MIN_VALUE, MAX_VALUE)

    def gen_social_value() -> int:
        return random.randrange(MIN_SOCIAL_VALUE, MAX_SOCIAL_VALUE)

    def gen_used_emoji_id() -> int:
        return random.choice(tuple(network.emojis.keys()))

    def gen_new_emoji_id() -> int:
        if random.randrange(10) == 0 and len(network.emojis) > 0:
            return gen_used_emoji_id()
        return random.randrange(MIN_EMOJI_ID, MAX_EMOJI_ID)

    def gen_emoji_id() -> int:
        if random.randrange(10) == 0 or len(network.emojis) == 0:
            return random.randrange(MIN_EMOJI_ID, MAX_EMOJI_ID)
        return gen_used_emoji_id()

    def gen_money() -> int:
        return random.randrange(MIN_MONEY, MAX_MONEY)

    def gen_notice() -> str:
        return ''.join(random.choice(NOTICE_ALPHABET) for _ in range(NOTICE_LEN))

    def gen_heat_limit() -> int:
        return random.randrange(MIN_HEAT_LIMIT, MAX_HEAT_LIMIT)

    def gen_ap() -> AddPersonRequest:
        return AddPersonRequest(gen_new_person_id(), gen_name(), gen_age())

    def gen_ar() -> AddRelationRequest:
        return AddRelationRequest(*gen_person_id_pair(), gen_value())

    def gen_qv() -> QueryValueRequest:
        return QueryValueRequest(*gen_person_id_pair())

    def gen_cn() -> CompareNameRequest:
        return CompareNameRequest(*gen_person_id_pair())

    def gen_qnr() -> QueryNameRankRequest:
        return QueryNameRankRequest(gen_person_id())

    def gen_qps() -> QueryPeopleSumRequest:
        return QueryPeopleSumRequest()

    def gen_qci() -> QueryCircleRequest:
        return QueryCircleRequest(*gen_person_id_pair())

    def gen_qbs() -> QueryBlockSumRequest:
        return QueryBlockSumRequest()

    def gen_ag() -> AddGroupRequest:
        return AddGroupRequest(gen_new_group_id())

    def gen_atg() -> AddToGroupRequest:
        return AddToGroupRequest(gen_person_id(), gen_group_id())

    def gen_qgs() -> QueryGroupSumRequest:
        return QueryGroupSumRequest()

    def gen_qgps() -> QueryGroupPeopleSumRequest:
        return QueryGroupPeopleSumRequest(gen_group_id())

    def gen_qgvs() -> QueryGroupValueSumRequest:
        return QueryGroupValueSumRequest(gen_group_id())

    def gen_qgam() -> QueryGroupAgeMeanRequest:
        return QueryGroupAgeMeanRequest(gen_group_id())

    def gen_qgav() -> QueryGroupAgeVarRequest:
        return QueryGroupAgeVarRequest(gen_group_id())

    def gen_dfg() -> DelFromGroupRequest:
        return DelFromGroupRequest(gen_person_id(), gen_group_id())

    def gen_am() -> Union[
            AddMessageRequest, AddRedEnvelopeMessageRequest,
            AddNoticeMessageRequest, AddEmojiMessageRequest]:
        _type = random.randrange(2)
        if _type == 0:
            # Person-to-person: prefer a recipient linked to the sender so
            # that the later send usually succeeds.
            id1 = gen_person_id()
            p1 = network.people.get(id1)
            if p1 is not None:
                p2_list = tuple(network.graph.neighbors(p1))
            else:
                p2_list = ()
            if random.randrange(10) == 0 or len(p2_list) == 0:
                id2 = gen_person_id()
            else:
                id2 = random.choice(p2_list).id
        else:
            # Person-to-group: prefer a sender who is a member of the group.
            id2 = gen_group_id()
            g = network.groups.get(id2)
            if random.randrange(10) == 0 or g is None or len(g.members) == 0:
                id1 = gen_person_id()
            else:
                id1 = random.choice(tuple(g.members)).id
        _id = gen_new_msg_id()
        cls = random.randrange(4)
        if cls == 0:
            return AddMessageRequest(_id, gen_social_value(), _type, id1, id2)
        elif cls == 1:
            return AddRedEnvelopeMessageRequest(_id, gen_money(), _type, id1, id2)
        elif cls == 2:
            return AddNoticeMessageRequest(_id, gen_notice(), _type, id1, id2)
        else:
            return AddEmojiMessageRequest(_id, gen_emoji_id(), _type, id1, id2)

    def gen_sm() -> SendMessageRequest:
        return SendMessageRequest(gen_msg_id())

    def gen_qsv() -> QuerySocialValueRequest:
        return QuerySocialValueRequest(gen_person_id())

    def gen_qrm() -> QueryReceivedMessagesRequest:
        return QueryReceivedMessagesRequest(gen_person_id())

    def gen_sei() -> StoreEmojiIdRequest:
        return StoreEmojiIdRequest(gen_new_emoji_id())

    def gen_qp() -> QueryPopularityRequest:
        return QueryPopularityRequest(gen_emoji_id())

    def gen_dce() -> DeleteColdEmojiRequest:
        return DeleteColdEmojiRequest(gen_heat_limit())

    def gen_qm() -> QueryMoneyRequest:
        return QueryMoneyRequest(gen_person_id())

    def gen_sim() -> SendIndirectMessageRequest:
        return SendIndirectMessageRequest(gen_msg_id())

    # Built once; every request kind is equally likely on every draw.
    generators = (
        gen_ap, gen_ar, gen_qv, gen_cn, gen_qnr, gen_qps, gen_qci, gen_qbs,
        gen_ag, gen_atg, gen_qgs, gen_qgps, gen_qgvs, gen_qgam, gen_qgav, gen_dfg,
        gen_am, gen_sm, gen_qsv, gen_qrm, gen_sei, gen_qp, gen_dce, gen_qm, gen_sim)
    for _ in range(REQUEST_COUNT):
        req: InputRequest = random.choice(generators)()
        requests.append(req)
        # Applying the request right away keeps `network` in step with the
        # ids the following generators may pick.
        responses.append(network.process_request(req))
    return requests, responses
def parse_input_data(input_data: list[str]) -> tuple[list[InputRequest], list[str]]:
    """Parse a fixed test case and compute its expected output.

    Args:
        input_data: Request lines; surrounding whitespace is stripped and
            blank lines are skipped.

    Returns:
        ``(requests, responses)`` where ``responses[i]`` is the reference
        reply to ``requests[i]``.

    Raises:
        ValueError: A line has an unknown opcode or a non-integer operand.
        IndexError: A line has too few operands.
    """
    requests: list[InputRequest] = []
    responses: list[str] = []
    network = PeopleNetwork()
    for line in input_data:
        line = line.strip()
        if not line:
            continue
        req = InputRequest.parse(line)
        requests.append(req)
        responses.append(network.process_request(req))
    return requests, responses
def do_test(_, config):
    """Run one test case against every configured subject program.

    Args:
        _: Ignored; present so the function can be mapped over a range by a
            process pool.
        config: Settings dict. Reads ``input`` (list of request lines, or
            ``None`` to generate a random case), ``subjects`` (each with
            ``name`` and ``cmd``), ``filters`` (compiled rules) and the
            optional ``log_dir``.

    Returns:
        A list with one JSON-serializable dict per reported failure. Details
        of each failure (input, expected answer, stdout, stderr) are written
        to text files in a fresh temporary directory named in ``log_dir``.
    """
    if config['input'] is None:
        requests, responses = gen_test_case()
    else:
        requests, responses = parse_input_data(config['input'])
    req_str = ''.join(f'{req}\n' for req in requests)
    subjects = config['subjects']
    errors = []
    for subject in subjects:
        try:
            try:
                proc = subprocess.run(
                    subject['cmd'],
                    input=req_str,
                    capture_output=True,
                    text=True,
                    timeout=10
                )
            except subprocess.TimeoutExpired:
                raise TestError('Time Limit Exceeded', req_str, responses, None, None)
            if proc.returncode != 0:
                raise TestError('Runtime Error', req_str, responses, proc, None)
            # zip_longest pads the shorter side with None, so missing or
            # surplus output lines are reported as mismatches too.
            paired = itertools.zip_longest(proc.stdout.splitlines(), responses)
            first_bad: Optional[int] = next(
                (number for number, (got, want) in enumerate(paired, start=1) if got != want),
                None)
            if first_bad is not None:
                raise TestError('Wrong Answer', req_str, responses, proc, first_bad)
        except TestError as e:
            # `err` goes into the returned summary; `err_data` holds the bulky
            # texts that are only written to the log directory.
            err = {'subject': subject['name'], 'reason': e.reason}
            err_data = {
                'input': e.input,
                'ans': ''.join(line + '\n' for line in e.ans),
            }
            if e.proc is not None:
                if e.proc.stdout is not None:
                    err_data['stdout'] = e.proc.stdout
                if e.proc.stderr is not None:
                    err_data['stderr'] = e.proc.stderr
                err['exit_code'] = e.proc.returncode
            if e.err_line is not None:
                err['error_line'] = e.err_line
            if not filter_error({**err, **err_data}, config['filters']):
                continue
            log_dir = Path(tempfile.mkdtemp(dir=config.get('log_dir')))
            err['log_dir'] = str(log_dir)
            for stem, text in err_data.items():
                (log_dir / f'{stem}.txt').write_text(text)
            errors.append(err)
    return errors
def compile_rule(rule: dict):
    """Compile, in place, every regex source string of a filter rule.

    All values except the one stored under ``'action'`` are replaced by their
    compiled pattern.
    """
    pattern_keys = [key for key in rule if key != 'action']
    for key in pattern_keys:
        rule[key] = re.compile(rule[key])
def filter_error(err, rules):
    """Decide whether a failure should be reported.

    Rules are tried in order. A rule matches when every one of its patterns
    (all keys except ``'action'``) is found in the corresponding field of
    ``err``; a field that is missing or ``None`` never matches. The first
    matching rule whose action is ``'accept'`` or ``'ignore'`` decides the
    result; a matching rule with any other action is skipped.

    Args:
        err: Failure description; values may be strings or integers (such as
            the exit code or the mismatching line number).
        rules: Rules with compiled regex patterns and an ``'action'`` entry.

    Returns:
        ``True`` to report the failure, ``False`` to drop it. Failures that
        no rule decides are reported.

    Raises:
        KeyError: A matching rule has no ``'action'`` entry.
    """
    for rule in rules:
        matched = True
        for key, pattern in rule.items():
            if key == 'action':
                continue
            value = err.get(key)
            # Non-string fields are matched on their text form; handing an
            # int to Pattern.search would raise TypeError.
            if value is None or pattern.search(str(value)) is None:
                matched = False
                break
        if matched:
            action = rule['action']
            if action == 'accept':
                return True
            elif action == 'ignore':
                return False
    return True
def main():
    """Command-line entry point.

    Loads the JSON config, overlays the command-line options on it, compiles
    the filter rules, then either replays one fixed input file or runs
    ``--count`` random test cases in a process pool. The collected failures
    are printed as JSON.
    """
    parser = ArgumentParser()
    parser.add_argument('--config', '-c', required=True)
    parser.add_argument('--log-dir', '-l')
    # A fixed input file and a random-case count exclude each other.
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--count', '-n', type=int, default=1)
    group.add_argument('--input', '-i')
    args = parser.parse_args()

    with open(args.config) as f:
        config = json.load(f)
    # Command-line values (including unset ones, as None) override the file.
    config.update(vars(args))
    for rule in config.setdefault('filters', []):
        compile_rule(rule)

    if config['input'] is not None:
        with open(config['input']) as f:
            config['input'] = f.readlines()
        errors = do_test(None, config)
    else:
        errors = []
        test_func = functools.partial(do_test, config=config)
        with multiprocessing.Pool() as pool:
            # Results arrive in completion order; idx numbers finished cases.
            results = pool.imap_unordered(test_func, range(args.count))
            for idx, res in enumerate(results):
                print(f'#{idx}: {len(res)}')
                errors.extend(res)
    print(json.dumps(errors, indent=2))


if __name__ == '__main__':
    main()
# vim: ts=4:sw=4:noet
# (GitHub page footer: "Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment")