class Node:
    """One segment-tree node covering the position range [l, r].

    ``s`` holds the sum and ``mx`` the maximum of the leaf values inside
    [l, r]. Every field starts at 0 and is filled in when the tree is built.
    """

    # A tree over n positions allocates 4n nodes; __slots__ avoids giving each
    # of them its own __dict__, which saves memory and speeds attribute access.
    __slots__ = ("l", "r", "s", "mx")

    def __init__(self):
        self.l = self.r = 0
        self.s = self.mx = 0
class SegmentTree:
    """Recursive segment tree over positions 1..n, every leaf starting at ``m``.

    Nodes are stored heap-style in ``self.tr``: the root is index 1 and node
    ``u`` has children ``2u`` and ``2u + 1``. Each node keeps the sum (``s``)
    and the maximum (``mx``) of the leaf values in its range ``[l, r]``.
    """

    def __init__(self, n, m):
        self.m = m  # initial value written into every leaf
        # 4n slots always suffice for a tree built by halving [1, n].
        self.tr = [Node() for _ in range(4 * n)]
        self.build(1, 1, n)

    def build(self, u, l, r):
        """Make node ``u`` cover [l, r] and fill its subtree with ``self.m``."""
        node = self.tr[u]
        node.l, node.r = l, r
        if l != r:
            mid = (l + r) // 2
            self.build(2 * u, l, mid)
            self.build(2 * u + 1, mid + 1, r)
            self.pushup(u)
        else:
            node.s = node.mx = self.m

    def modify(self, u, x, v):
        """Overwrite leaf ``x`` with ``v`` and refresh the aggregates above it."""
        node = self.tr[u]
        if node.l == x == node.r:
            node.s = node.mx = v
            return
        mid = (node.l + node.r) // 2
        self.modify(2 * u if x <= mid else 2 * u + 1, x, v)
        self.pushup(u)

    def query_sum(self, u, l, r):
        """Return the sum of leaf values in [l, r] that lie inside node ``u``."""
        node = self.tr[u]
        if l <= node.l and node.r <= r:
            return node.s
        mid = (node.l + node.r) // 2
        total = 0
        if l <= mid:
            total += self.query_sum(2 * u, l, r)
        if mid < r:
            total += self.query_sum(2 * u + 1, l, r)
        return total

    def query_idx(self, u, l, r, k):
        """Return the leftmost position whose value is ``>= k``, bounded by ``r``.

        Returns 0 when there is no such position (0 is never a real position,
        since positions start at 1). ``l`` is never read: the search always
        descends from the left edge of node ``u``'s range, so only prefix
        ranges are honoured.
        """
        node = self.tr[u]
        if node.mx < k:
            return 0
        if node.l == node.r:
            return node.l
        left = 2 * u
        if self.tr[left].mx >= k:
            # A match on the left is always further left than any on the right.
            return self.query_idx(left, l, r, k)
        mid = (node.l + node.r) // 2
        # Only step right while part of the right child's range is still <= r.
        return self.query_idx(left + 1, l, r, k) if mid < r else 0

    def pushup(self, u):
        """Recompute node ``u``'s sum and max from its two children."""
        lo, hi = self.tr[2 * u], self.tr[2 * u + 1]
        parent = self.tr[u]
        parent.s = lo.s + hi.s
        parent.mx = max(lo.mx, hi.mx)
from typing import List


class BookMyShow:
    """Ticket booking over ``n`` rows of ``m`` seats each.

    Callers use 0-based rows; the segment tree stores them 1-based, with each
    leaf holding the number of free seats left in that row. Only counts are
    stored, so a row's free seats are taken to be its trailing ones: a row
    with ``s`` free seats has its first free seat at index ``m - s``.
    """

    def __init__(self, n: int, m: int):
        self.n = n
        self.tree = SegmentTree(n, m)

    def gather(self, k: int, maxRow: int) -> List[int]:
        """Seat ``k`` people side by side in a single row numbered ``<= maxRow``.

        Uses the lowest-numbered row that still has ``k`` free seats and
        returns ``[row, first_seat]``. Returns ``[]`` and books nothing when no
        such row exists.
        """
        maxRow += 1  # convert to the tree's 1-based row index
        i = self.tree.query_idx(1, 1, maxRow, k)
        if i == 0:
            return []
        s = self.tree.query_sum(1, i, i)  # free seats currently in row i
        self.tree.modify(1, i, s - k)
        return [i - 1, self.tree.m - s]

    def scatter(self, k: int, maxRow: int) -> bool:
        """Seat ``k`` people anywhere in rows ``<= maxRow``, filling low rows first.

        Returns False and books nothing when rows 0..maxRow hold fewer than
        ``k`` free seats in total; otherwise books the seats and returns True.
        """
        if k <= 0:
            # Nothing to place. Without this guard, an all-full prefix makes
            # query_idx return 0 and the loop below would probe row index 0,
            # which does not exist in the 1-based tree.
            return True
        maxRow += 1  # convert to the tree's 1-based row index
        if self.tree.query_sum(1, 1, maxRow) < k:
            return False
        # Start at the first row (<= maxRow) that has at least one free seat.
        i = self.tree.query_idx(1, 1, maxRow, 1)
        for j in range(i, self.n + 1):
            s = self.tree.query_sum(1, j, j)
            if s >= k:
                self.tree.modify(1, j, s - k)
                return True
            k -= s
            self.tree.modify(1, j, 0)
        # Not reached: the prefix-sum check above guarantees enough seats.
        return True
# Your BookMyShow object will be instantiated and called as such:
# obj = BookMyShow(n, m)
# param_1 = obj.gather(k,maxRow)
# param_2 = obj.scatter(k,maxRow)