#!/usr/bin/env python
# coding: utf-8

# Run in a 16GB kernel

# ## some imports

# In[44]:


# Import general python packages
import os
import time

import numpy as np
import matplotlib.pyplot as plt
import pandas
from pandas.testing import assert_frame_equal
from astropy import units as u
from astropy.coordinates import SkyCoord
from astropy.io import ascii

# Import the Rubin TAP service utilities
from lsst.rsp import get_tap_service, retrieve_query


# ## check the TAP service

# In[2]:


# Connect to the Rubin Science Platform TAP service and confirm it points at
# the expected endpoint before any queries are issued.
service = get_tap_service("tap")
assert service is not None
assert service.baseurl == "https://data.lsst.cloud/api/tap"
# Show the row limits reported by the service, one per line.
print(service.maxrec, service.hardlimit, sep="\n")


# ![image.png](attachment:8172b20c-8bb9-40a3-81ed-cfaac47b187b.png)

# ## Numbers
# DiaObject has 41 million rows, DiaSource has 162 million rows, and ForcedSourceOnDiaObject has 17 billion rows.
# Because the TAP row limit is 134 million, all diaObjectIds and their nDiaSources can be retrieved in a single query.

# In[3]:


# Fetch every DiaObject id together with its DiaSource count in one query.
results = service.search(
    "SELECT diaObjectId,nDiaSources from dp02_dc2_catalogs.DiaObject"
)


# The query above takes about 5 minutes.
# Convert the result to a DataFrame and sort it by nDiaSources (descending), then diaObjectId (ascending), so that the objects with the most observations come first. (The sort could also be done in Qserv, but ...)

# In[4]:


# Convert the TAP result to a DataFrame and order it so that the objects with
# the most DiaSources come first; ties are broken by ascending diaObjectId.
sortedDF = (
    results.to_table()
    .to_pandas(index=False)
    .sort_values(by=['nDiaSources', 'diaObjectId'], ascending=[False, True])
)


# To estimate the amount of data and build queries that each return a similar amount, walk the cumulative total of nDiaSources
# and cut the sorted list into chunks of about 50,000 DiaSources, recording the DataFrame row index at which each chunk ends.

# In[50]:


# Walk the objects in sorted order, accumulating nDiaSources. Whenever the
# running total exceeds the threshold, record that row position as the last
# row of a chunk and start a new total. Every chunk except the final one
# therefore holds more than `cumulative_threshold` DiaSources: it ends with the
# row that pushed the total over the threshold.
cumulative_sum = 0
reset_indices = []
cumulative_threshold = 50000

for i, value in enumerate(sortedDF['nDiaSources']):
    cumulative_sum += value
    if cumulative_sum > cumulative_threshold:
        reset_indices.append(i)
        cumulative_sum = 0

# Close the final, partially filled chunk so every row is covered.
# - reset_indices may be empty (the total never exceeded the threshold), so
#   do not index it blindly.
# - An empty DataFrame has no last row, so nothing is appended.
last_row = len(sortedDF) - 1
if last_row >= 0 and (not reset_indices or reset_indices[-1] != last_row):
    reset_indices.append(last_row)

# Display the reset indices (positional row numbers in sortedDF).
print("Reset Indices:", reset_indices)


# Now cycle through the indices and, for each chunk, build a query on each of the 3 tables using a `diaObjectId IN (...)` clause.

# ## the where clause

# In[53]:


# Build one comma-separated diaObjectId list per chunk, suitable for an
# SQL "IN (...)" clause. Chunk k covers sorted rows
# (reset_indices[k-1], reset_indices[k]] (the first chunk starts at row 0).
# Consecutive chunks are therefore contiguous and non-overlapping.
#
# The column is selected by name rather than by position (iloc[..., 0]), so
# the code does not depend on the column order of the TAP result.
sorted_ids = sortedDF['diaObjectId']
where_clauses = []
allIdsList = []
prev_index = 0
for index in reset_indices:
    diaObjectIds = sorted_ids.iloc[prev_index:index + 1].tolist()
    where_clauses.append(','.join(map(str, diaObjectIds)))
    allIdsList.extend(diaObjectIds)
    prev_index = index + 1

allIds = len(allIdsList)

# These should agree if every object landed in exactly one chunk:
# ids placed in chunks, rows returned by the DiaObject query, and the
# length of the concatenated id list.
print(allIds)
print(len(results))
print(len(allIdsList))
# NOTE(review): the chunks are disjoint by construction, so duplicates in
# allIdsList could only come from repeated diaObjectIds in the query result.


# ## the queries

# In[61]:


# For each chunk of diaObjectIds, download the matching rows of three DP0.2
# tables to CSV files named "<chunk number, zero-padded to 6 digits>_<label>.csv".
# Files that already exist are skipped, so re-running this cell resumes an
# interrupted download instead of repeating finished queries.

# (file label, table name, table alias) for each table pulled per chunk.
chunk_tables = (
    ("Object", "DiaObject", "o"),
    ("Source", "DiaSource", "s"),
    ("Forced", "ForcedSourceOnDiaObject", "f"),
)
# Only the first `max_chunks` chunks are processed in one run of this cell.
max_chunks = 500

loop = 0
for loop, where_clause in enumerate(where_clauses, start=1):
    for label, table, alias in chunk_tables:
        out_file = f"{loop:06d}_{label}.csv"
        if os.path.exists(out_file):
            continue
        query = (
            f"select {alias}.* from dp02_dc2_catalogs.{table} as {alias} "
            f"where {alias}.diaObjectId in ({where_clause})"
        )
        print(loop, "querying " + label)
        # NOTE(review): service.search() is subject to the service row limit
        # (service.maxrec); confirm the Forced chunks stay below it.
        table_result = service.search(query).to_table()
        # Write to a temporary name, then rename once the write has finished.
        # An interrupted write therefore never leaves a truncated CSV that the
        # os.path.exists() check above would treat as complete.
        tmp_file = out_file + ".part"
        table_result.write(tmp_file, format='csv', overwrite=True)
        os.replace(tmp_file, out_file)
    if loop == max_chunks:
        break


# In[ ]:




