import { useSetRecoilState } from 'recoil';
import { useTranslation } from 'react-i18next';
import useRefreshSession from '~/src/Auth/hooks/useRefreshSession';
import useApiClient from '~/src/hooks/useApiClient';
import useProjectId from '~/src/hooks/useProjectId';
import { snackbarMessageSelector } from '~/src/Snackbar/store';
import { pendingState } from '../store';

/** Translation keys (in the `projects` namespace) used for failure snackbars. */
type StartTrainingErrorKey =
  | 'productionFile.errors.startTraining'
  | 'productionFile.errors.startPredicting';

/**
 * Hook exposing actions that start automodeler training and prediction.
 *
 * Each action:
 * - sets the shared `pendingState` to `true` while its request is in flight
 *   and resets it to `false` afterwards, on success or failure;
 * - refreshes the auth session before calling the API;
 * - on failure, shows a translated error snackbar and then rethrows the
 *   original error so callers can also react to it.
 *
 * @returns A readonly tuple `[startTraining, startPredicting]`.
 */
const useStartTraining = () => {
  const { t } = useTranslation('projects');
  const [refreshSession] = useRefreshSession();
  const apiClient = useApiClient();
  const projectId = useProjectId();
  const setPending = useSetRecoilState(pendingState);
  const setSnackbarMessage = useSetRecoilState(snackbarMessageSelector);

  /**
   * Shared wrapper for both actions. It toggles the pending flag, refreshes
   * the session, then runs `request`. If either step fails, it shows
   * `errorKey` as an error snackbar and rethrows.
   */
  const runWithPending = async (
    errorKey: StartTrainingErrorKey,
    request: () => Promise<unknown>,
  ): Promise<void> => {
    setPending(true);
    try {
      await refreshSession();
      await request();
    } catch (err) {
      setSnackbarMessage({
        message: t(errorKey),
        severity: 'error',
      });
      throw err;
    } finally {
      setPending(false);
    }
  };

  /**
   * Starts training via `PUT /automodeler/:id/analyze`.
   *
   * @param id - Project to train. Defaults to the current project from
   *   `useProjectId` when omitted.
   */
  const startTraining = (id?: number): Promise<void> =>
    runWithPending('productionFile.errors.startTraining', () =>
      // `??` rather than `||` so only a missing id falls back to the current project.
      apiClient.put(`/automodeler/${id ?? projectId}/analyze`),
    );

  /**
   * Starts prediction for the current project via
   * `PUT /automodeler/:projectId/predict`.
   *
   * @param modelName - Optional model to predict with. It is sent as the
   *   `modelName` query param only when non-empty.
   */
  const startPredicting = (modelName?: string): Promise<void> =>
    runWithPending('productionFile.errors.startPredicting', () => {
      const params: { modelName?: string } = modelName ? { modelName } : {};
      return apiClient.put(`/automodeler/${projectId}/predict`, null, { params });
    });

  return [startTraining, startPredicting] as const;
};

export default useStartTraining;
