Submission #336637

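| # | Time | Username | Problem | Language | Result | Execution time | Memory |
|---|------|----------|---------|----------|--------|----------------|--------|
| 336637 | (not shown) | ChrisT | 영역 (Territory, JOI16_ho_t4) | C++17 | 100 / 100 | 382 ms | 180204 KiB |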
#include <bits/stdc++.h>
using namespace std;
// Global capacity used for fixed-size buffers/containers (e.g. the 1-indexed
// move-string buffer `s`). Integer literal instead of the double `5e5 + 5`.
constexpr int MN = 500'005;
// Incremental bookkeeping for two sparse lines of lattice points, `row` and
// `row2`, keyed by x coordinate.
//
//   row / row2 : x -> multiplicity. A key is erased as soon as its multiplicity
//                becomes exactly zero.
//   dub        : every c such that both c and c+1 are keys of `row`.
//   dub2       : every c such that both c and c+1 are keys of `row2`.
//   intersect  : every c present in both dub and dub2, i.e. every c for which
//                columns c and c+1 are present on both lines.
//
// add(i, v) / add2(i, v) change the multiplicity of x = i in row / row2 by v and
// keep dub, dub2 and intersect consistent.
// NOTE(review): adding to an absent key inserts it with multiplicity v even if
// v <= 0; callers presumably only subtract what they previously added.
struct Stupid {
	map<long long,long long> row,row2;
	set<long long> dub,dub2,intersect;

	void add (long long i, long long v) { adjust(row, dub, dub2, i, v); }
	void add2 (long long i, long long v) { adjust(row2, dub2, dub, i, v); }

private:
	// Shared body of add/add2. `line` is the map being changed, `pairs` its
	// adjacent-pair set, `otherPairs` the adjacent-pair set of the other line.
	void adjust(map<long long,long long>& line, set<long long>& pairs,
	            const set<long long>& otherPairs, long long i, long long v) {
		bool appeared;
		auto it = line.find(i);
		if (it == line.end()) {
			line.emplace(i, v);
			appeared = true;
		} else {
			it->second += v;
			if (it->second != 0) return;  // still present: no pair can change
			line.erase(it);
			appeared = false;
		}
		// x = i belongs to the pairs starting at i-1 and at i; a pair flips only
		// when its other member is present.
		auto touch = [&](long long neighbour, long long left) {
			if (!line.count(neighbour)) return;
			if (appeared) {
				pairs.insert(left);
				if (otherPairs.count(left)) intersect.insert(left);
			} else {
				pairs.erase(left);
				if (otherPairs.count(left)) intersect.erase(left);
			}
		};
		touch(i - 1, i - 1);
		touch(i + 1, i);
	}
};
// Move string, stored 1-indexed in s[1..n].
char s[MN];

// Reads n, k and a move string of length n over {'N','S','E','W'}, walks the
// string k times in a row starting at (0,0), and prints how many unit squares
// have all four corners visited.
//
// Coordinates used below: 'E' is x+1, 'W' is x-1, 'S' is y+1, 'N' is y-1.
int main() {
	int n; long long k;
	// Unchecked, a malformed input would leave n/k uninitialised (this was the
	// -Wunused-result warning on scanf).
	if (scanf ("%d %lld\n%s",&n,&k,s+1) != 3) return 1;

	// (dx, dy) = net displacement of one pass over the string.
	int dx = 0, dy = 0;
	for (int i = 1; i <= n; i++) {
		if (s[i] == 'E') dx++;
		else if (s[i] == 'W') dx--;
		else if (s[i] == 'S') dy++;
		else dy--;
	}
	// Mirroring the walk does not change the number of squares, so normalise the
	// drift to dx >= 0 and dy >= 0 by swapping E<->W and/or S<->N.
	if (dx < 0) {
		for (int i = 1; i <= n; i++) {
			if (s[i] == 'E') s[i] = 'W';
			else if (s[i] == 'W') s[i] = 'E';
		}
		dx = -dx;
	}
	if (dy < 0) {
		for (int i = 1; i <= n; i++) {
			if (s[i] == 'S') s[i] = 'N';
			else if (s[i] == 'N') s[i] = 'S';
		}
		dy = -dy;
	}
	if (dx == 0 && dy == 0) {
		// Zero drift: every pass retraces the first one, so k is irrelevant.
		// Collect the points of a single pass and count squares directly.
		set<pair<int,int>> st;
		dx = 0, dy = 0; st.emplace(0,0);
		for (int i = 1; i <= n; i++) {
			if (s[i] == 'E') dx++;
			else if (s[i] == 'W') dx--;
			else if (s[i] == 'S') dy++;
			else dy--;
			st.emplace(dx,dy);
		}
		int res = 0;
		for (const auto& [x,y] : st) {
			if (st.count({x+1,y}) && st.count({x,y+1}) && st.count({x+1,y+1})) {
				res++;
			}
		}
		printf ("%d\n",res);
		return 0;
	}
	// Purely horizontal drift: reflect across the diagonal (E<->S, W<->N) so
	// that from here on dy > 0 and dx >= 0.
	if (dy == 0) {
		for (int i = 1; i <= n; i++) {
			if (s[i] == 'E') s[i] = 'S';
			else if (s[i] == 'W') s[i] = 'N';
			else if (s[i] == 'S') s[i] = 'E';
			else s[i] = 'W';
		}
		swap(dx,dy);
	}
	// top = minimum y reached during one pass (<= 0); y - top is then a
	// non-negative level index.
	int top = 0, cur = 0;
	for (int i = 1; i <= n; i++) {
		if (s[i] == 'N') top = min(top,--cur);
		else if (s[i] == 'S') ++cur;
	}
	// monke[y - top][x] = number of times the first pass visits (x, y).
	// y - top lies in [0, n] (a pass has at most n vertical moves) and the
	// sweep below also reads level mxy + 1 <= n + 1, so n + 2 levels suffice.
	vector<map<long long,int>> monke(n + 2);
	int curx = 0, cury = 0, mxy = -top;
	monke[-top][0]++;
	for (int i = 1; i <= n; i++) {
		if (s[i] == 'E') ++curx;
		else if (s[i] == 'W') --curx;
		else if (s[i] == 'N') --cury;
		else ++cury;
		mxy = max(mxy,cury - top);
		monke[cury - top][curx]++;
	}
	// Sweep over levels. Pass p visits the first pass's points shifted by
	// (p*dx, p*dy), so level i receives monke[i - p*dy] shifted by p*dx.
	// row[i].row holds the x's on level i and row[i].row2 those on level i - 1,
	// stored as x - lz[i]; row[i].intersect is then the squares spanning levels
	// i - 1 and i. Pass k (and later) is subtracted once i >= k*dy.
	//   [0, L)                      : levels built explicitly;
	//   [L, R]                      : nothing added or removed, so the count
	//                                 repeats per residue class mod dy;
	//   [max(L, R+1), R + 2 + mxy]  : only removals of pass k remain.
	long long ret = 0, L = mxy + 2, R = k * dy - 1;
	// Every level index used below is < L and every residue index is < dy.
	vector<Stupid> row(L);
	vector<long long> lz(L); vector<int> lst(dy);
	for (int i = 0; i < L; i++) {
		lst[i%dy] = i;  // latest level seen for this residue class
		if (i >= dy) {
			// Inherit level i - dy (all passes shift by dx via the lazy offset),
			// then add pass 0's points on levels i and i - 1.
			lz[i] = lz[i-dy] + dx; swap(row[i],row[i-dy]);
			for (const auto& [j,cnt] : monke[i]) row[i].add(j-lz[i],cnt);
			for (const auto& [j,cnt] : monke[i-1]) row[i].add2(j-lz[i],cnt);
		} else {
			for (const auto& [j,cnt] : monke[i])
				row[i].add(j,cnt);
			if (i) for (const auto& [j,cnt] : monke[i-1])
				row[i].add2(j,cnt);
		}
		if (i >= k * dy) {
			// Pass k has reached this level; it is not part of the walk.
			for (const auto& [j,cnt] : monke[i-dy*k]) {
				row[i].add(j + k * dx - lz[i],-cnt);
			}
			if (i - 1 >= k * dy) {
				for (const auto& [j,cnt] : monke[i - dy * k - 1]) {
					row[i].add2(j + k * dx - lz[i],-cnt);
				}
			}
		}
		if (!row[i].row.empty()) {
			ret += (int)row[i].intersect.size();
		}
	}
	for (int res = 0; res < dy; res++) {
		// Number of levels in [L, R] congruent to res modulo dy.
		long long cnt = max(0LL,(R - res) / dy - (L - 1 - res) / dy);
		ret += (long long)row[lst[res]].intersect.size() * cnt; lz[lst[res]] += dx * cnt;
	}
	for (long long i = max(L,R+1); i <= R + 2 + mxy; i++) {
		int idx = lst[i % dy]; lz[idx] += dx;  // advance this residue by one level
		for (const auto& [j,cnt] : monke[i - dy * k]) {
			row[idx].add(j + k * dx - lz[idx],-cnt);
		}
		if (i >= dy * k + 1) for (const auto& [j,cnt] : monke[i - dy * k - 1]) {
			row[idx].add2(j + k * dx - lz[idx], -cnt);
		}
		if (!row[idx].row.empty()) {
			ret += (int)row[idx].intersect.size();
		}
	}
	printf ("%lld\n",ret);
	return 0;
}

Compilation message (stderr)

2016_ho_t4.cpp: In function 'int main()':
2016_ho_t4.cpp:62:8: warning: ignoring return value of 'int scanf(const char*, ...)', declared with attribute warn_unused_result [-Wunused-result]
   62 |  scanf ("%d %lld\n%s",&n,&k,s+1);
      |  ~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...
| # | Verdict | Execution time | Memory | Grader output |
Fetching results...